Commit e42f9ecf authored by illsilin's avatar illsilin
Browse files

enable fp8 gemm_xdl for all gfx9 targets

parent 68459244
...@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) ...@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_gemm_wmma) add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
...@@ -66,13 +66,11 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -66,13 +66,11 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif() endif()
endforeach() endforeach()
if(GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
endif()
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif() endif()
...@@ -279,8 +279,7 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) ...@@ -279,8 +279,7 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
switch(conv_param.num_dim_spatial_) switch(conv_param.num_dim_spatial_)
{ {
// case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2: case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
// case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
} }
......
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
......
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma) add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
......
...@@ -56,8 +56,7 @@ __global__ void ...@@ -56,8 +56,7 @@ __global__ void
bool input_permute, bool input_permute,
bool output_permute) bool output_permute)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -162,7 +161,7 @@ __global__ void ...@@ -162,7 +161,7 @@ __global__ void
ignore = G1; ignore = G1;
ignore = input_permute; ignore = input_permute;
ignore = output_permute; ignore = output_permute;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Self-Attention // Self-Attention
...@@ -188,8 +187,7 @@ __global__ void ...@@ -188,8 +187,7 @@ __global__ void
index_t head_size, index_t head_size,
float alpha) float alpha)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -294,7 +292,7 @@ __global__ void ...@@ -294,7 +292,7 @@ __global__ void
ignore = head_count; ignore = head_count;
ignore = head_size; ignore = head_size;
ignore = alpha; ignore = alpha;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Cross-Attention // Cross-Attention
// Self-Attention // Self-Attention
...@@ -323,8 +321,7 @@ __global__ void ...@@ -323,8 +321,7 @@ __global__ void
index_t head_size, index_t head_size,
float alpha) float alpha)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -435,7 +432,7 @@ __global__ void ...@@ -435,7 +432,7 @@ __global__ void
ignore = head_count; ignore = head_count;
ignore = head_size; ignore = head_size;
ignore = alpha; ignore = alpha;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// MN = MK * KL * LN // MN = MK * KL * LN
...@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout, ...@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>)) is_same_v<AccDataType, int32_t>))
......
...@@ -61,8 +61,7 @@ __global__ void ...@@ -61,8 +61,7 @@ __global__ void
bool input_permute, bool input_permute,
bool output_permute) bool output_permute)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -169,7 +168,7 @@ __global__ void ...@@ -169,7 +168,7 @@ __global__ void
ignore = G1; ignore = G1;
ignore = input_permute; ignore = input_permute;
ignore = output_permute; ignore = output_permute;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
...@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -60,8 +60,7 @@ __global__ void ...@@ -60,8 +60,7 @@ __global__ void
bool input_permute, bool input_permute,
bool output_permute) bool output_permute)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -168,7 +167,7 @@ __global__ void ...@@ -168,7 +167,7 @@ __global__ void
ignore = G1; ignore = G1;
ignore = input_permute; ignore = input_permute;
ignore = output_permute; ignore = output_permute;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
...@@ -595,8 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma ...@@ -595,8 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -952,8 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma ...@@ -952,8 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment