Commit a8d19fd9 authored by yuguo
Parents: 9d0f1c9b 6dfe66e9
@@ -1573,6 +1573,10 @@ def test_grouped_linear_accuracy(
             weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
             sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+    # Force the sequential_linear and grouped_linear to use hipBLASLt rather than hipBLAS
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )
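Both this hunk and the test_grouped_gemm hunk further down set the flag with a bare os.environ assignment, which stays in effect for every test that runs later in the same process. A minimal scoped alternative, sketched here under the assumption that pytest's monkeypatch fixture is available (it is not used by this commit):

```python
import os

# Hypothetical variant (not part of this commit): pytest's monkeypatch fixture
# undoes the environment change when the test ends, so forcing the hipBLASLt
# path cannot leak into tests that run afterwards in the same process.
def test_grouped_linear_accuracy_forced_gemm(monkeypatch):
    monkeypatch.setenv("NVTE_FORCE_ROCM_GEMM", "1")
    assert os.environ["NVTE_FORCE_ROCM_GEMM"] == "1"
    # ... run the grouped-linear accuracy comparison here ...
```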
@@ -2087,7 +2091,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
 @pytest.mark.parametrize("is_paged", [False, True])
 def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
     reset_rng_states()
+    if backend in ["FusedAttention"]:
+        pytest.skip("FusedAttention is not supported")
     if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
         pytest.skip("FusedAttention and FlashAttention do not support FP32")
     if use_RoPE:
@@ -2268,6 +2274,10 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
     grad = True
     single_output = False
+    # Force the grouped GEMM to use hipBLASLt rather than hipBLAS
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
     for i in range(z):
         general_gemm(
             A[i],
...
@@ -1459,7 +1459,7 @@ void rocblas_gemm(const Tensor *inputA,
   // extract the stream order alloc env
   bool stream_order_alloc = false;
   if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC") ) {
-    if (env_p != nullptr && std::string(env_p) == "1")
+    if (env_p == nullptr || std::string(env_p) == "1")
       stream_order_alloc = true;
   }
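For readers of the new condition: because the enclosing if only executes when getenv returns a non-null pointer, the env_p == nullptr half can never be true inside it, and the overall behaviour remains "enabled only when the variable is set to 1". A Python sketch of the same control flow (illustrative only, not TE's API):

```python
import os

# Sketch of the control flow above. The outer check only enters when the
# variable is set, so the "v is None" half of the updated condition can never
# be true there and the inner test reduces to v == "1".
def stream_order_alloc_enabled() -> bool:
    v = os.environ.get("ROCBLAS_STREAM_ORDER_ALLOC")
    if v is not None:
        return v is None or v == "1"  # mirrors the new C++ condition verbatim
    return False
```

If default-on behaviour for an unset variable were the goal, the nullptr branch would have to be evaluated outside the enclosing if, since that guard filters out the unset case first.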
@@ -1467,6 +1467,7 @@ void rocblas_gemm(const Tensor *inputA,
   NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) ||
              (A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f32_r) ||
              (A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) ||
+             (A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_f32_r) ||
              (A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) ||
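The NVTE_CHECK whitelists (A, B, D) datatype triples for this rocBLAS path, and the added line admits bf16 inputs with an fp32 output; the check continues past this excerpt. A table-driven Python sketch of the visible combinations (names are illustrative stand-ins, not TE's API):

```python
# Illustrative stand-ins for the rocblas_datatype_* enums; the real NVTE_CHECK
# continues past this excerpt with further combinations.
SUPPORTED_GEMM_DTYPES = {
    ("f16", "f16", "f16"),
    ("f16", "f16", "f32"),
    ("bf16", "bf16", "bf16"),
    ("bf16", "bf16", "f32"),  # the combination added by this commit
    ("f32", "f32", "f32"),
}

def check_gemm_dtypes(a_type: str, b_type: str, d_type: str) -> None:
    # Hypothetical helper: reject any (A, B, D) triple outside the whitelist.
    if (a_type, b_type, d_type) not in SUPPORTED_GEMM_DTYPES:
        raise ValueError(f"unsupported GEMM dtype combination: {(a_type, b_type, d_type)}")
```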