Commit 7a47930f authored by wenjh's avatar wenjh
Browse files

[ROCBLAS_GEMM] Default use of hipMallocAsync



Default use of hipMallocAsync rather than hipMalloc in rocblas_gemm and
add support of fp16_fp16_fp32 in rocblas_gemm.
Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent e8f92b93
...@@ -2087,7 +2087,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -2087,7 +2087,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
@pytest.mark.parametrize("is_paged", [False, True]) @pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states() reset_rng_states()
if backend in ["FusedAttention"]:
pytest.skip("Not support FusedAttention")
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32") pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE: if use_RoPE:
......
...@@ -1459,7 +1459,7 @@ void rocblas_gemm(const Tensor *inputA, ...@@ -1459,7 +1459,7 @@ void rocblas_gemm(const Tensor *inputA,
// extract the stream order alloc env // extract the stream order alloc env
bool stream_order_alloc = false; bool stream_order_alloc = false;
if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC") ) { if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC") ) {
if (env_p != nullptr && std::string(env_p) == "1") if (env_p == nullptr || std::string(env_p) == "1")
stream_order_alloc = true; stream_order_alloc = true;
} }
...@@ -1467,6 +1467,7 @@ void rocblas_gemm(const Tensor *inputA, ...@@ -1467,6 +1467,7 @@ void rocblas_gemm(const Tensor *inputA,
NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) || NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) || (A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_f32_r) || (A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) || (A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment