Commit e42f9ecf authored by illsilin's avatar illsilin
Browse files

enable fp8 gemm_xdl for all gfx9 targets

parent 68459244
......@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
......@@ -66,13 +66,11 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
if(GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
endif()
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
......@@ -279,8 +279,7 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
switch(conv_param.num_dim_spatial_)
{
// case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2:
return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
// case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
}
......
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
......
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
......
......@@ -56,8 +56,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
......@@ -162,7 +161,7 @@ __global__ void
ignore = G1;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
// Self-Attention
......@@ -188,8 +187,7 @@ __global__ void
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
......@@ -294,7 +292,7 @@ __global__ void
ignore = head_count;
ignore = head_size;
ignore = alpha;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
// Cross-Attention
// Self-Attention
......@@ -323,8 +321,7 @@ __global__ void
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
......@@ -435,7 +432,7 @@ __global__ void
ignore = head_count;
ignore = head_size;
ignore = alpha;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
......@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
......@@ -61,8 +61,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
......@@ -169,7 +168,7 @@ __global__ void
ignore = G1;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
// Computes C = A * B0 * B1
......@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -60,8 +60,7 @@ __global__ void
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// clang-format off
// ***************************************************
......@@ -168,7 +167,7 @@ __global__ void
ignore = G1;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
// Computes C = A * B0 * B1
......@@ -595,8 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -952,8 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
......
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment