gaoqiong / composable_kernel_ROCM · Commits

Commit e42f9ecf
Authored Mar 07, 2024 by illsilin

    enable fp8 gemm_xdl for all gfx9 targets

Parent: 68459244
Showing 11 changed files, with 31 additions and 46 deletions.
example/01_gemm/CMakeLists.txt                                                                          +7  -9
example/29_batched_gemm_bias_e_permute/CMakeLists.txt                                                   +1  -1
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc              +2  -3
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt                                               +1  -1
example/64_fpAintB_gemm/CMakeLists.txt                                                                  +1  -1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  +8  -13
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp                                +1  -2
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp             +4  -7
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp               +4  -7
test/grouped_convnd_bwd_data/CMakeLists.txt                                                             +1  -1
test/grouped_convnd_bwd_weight/CMakeLists.txt                                                           +1  -1
example/01_gemm/CMakeLists.txt

...
@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
 add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
 add_custom_target(example_gemm_wmma)
 add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
 add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
...
@@ -66,13 +66,11 @@ foreach(gpu IN LISTS GPU_TARGETS)
 endif()
 endforeach()
-if(GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
 add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
 add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
-endif()
example/29_batched_gemm_bias_e_permute/CMakeLists.txt

 add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
 add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
 endif()
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc

...
@@ -279,9 +279,8 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
     switch(conv_param.num_dim_spatial_)
     {
     // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
     case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
     // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
     }
     return false;
...
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
 add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
...
example/64_fpAintB_gemm/CMakeLists.txt

-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
 add_custom_target(example_fpAintB_gemm_wmma)
 add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
 add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

...
@@ -56,8 +56,7 @@ __global__ void
         bool input_permute,
         bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -162,7 +161,7 @@ __global__ void
     ignore = G1;
     ignore = input_permute;
     ignore = output_permute;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }
 // Self-Attention
...
@@ -188,8 +187,7 @@ __global__ void
         index_t head_size,
         float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -294,7 +292,7 @@ __global__ void
     ignore = head_count;
     ignore = head_size;
     ignore = alpha;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }
 // Cross-Attention
 // Self-Attention
...
@@ -323,8 +321,7 @@ __global__ void
         index_t head_size,
         float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -435,7 +432,7 @@ __global__ void
     ignore = head_count;
     ignore = head_size;
     ignore = alpha;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }
 // Computes C = A * B0 * B1
 // MN = MK * KL * LN
...
@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
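Note on the device-side pattern above: this commit collapses the explicit list of per-ASIC macros (__gfx1100__, __gfx1101__, __gfx1102__) into the single __gfx11__ family macro. The following is a minimal sketch of that guard shape, not code from this commit: toy_kernel is hypothetical, and it assumes, as the commit does, that __gfx11__ is defined whenever any gfx11xx ASIC is the compilation target. The !defined(__HIP_DEVICE_COMPILE__) arm keeps the full body visible to the host compilation pass so kernel launches still type-check.

    #include <hip/hip_runtime.h>

    // Minimal sketch of the guard pattern; toy_kernel is hypothetical.
    __global__ void toy_kernel(float* out)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
        // Only the host pass and gfx11-family device passes compile a real body.
        out[threadIdx.x] = static_cast<float>(threadIdx.x);
    #else
        // Every other device target gets an empty stub; silence the unused
        // parameter, as the "ignore = ...;" lines do in the kernels above.
        (void)out;
    #endif
    }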
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp

...
@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                            is_same_v<AccDataType, int32_t>))
...
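The host-side half of the change swaps per-ASIC device-name comparisons for ck::is_navi3_supported(). That helper's definition is not part of this diff, so the sketch below is only a hedged stand-in that would behave like the string comparisons it replaces; the inclusion of gfx1103 is an assumption, suggested by the test CMakeLists later in this commit adding that target to the WMMA lists.

    #include <string>

    namespace ck {
    std::string get_device_name(); // provided elsewhere in CK; declaration only

    // Hedged sketch: not CK's actual implementation, which this diff omits.
    inline bool is_navi3_supported()
    {
        const std::string dev = get_device_name();
        return dev == "gfx1100" || dev == "gfx1101" || dev == "gfx1102" ||
               dev == "gfx1103"; // gfx1103 inclusion is an assumption
    }
    } // namespace ck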
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp

...
@@ -61,8 +61,7 @@ __global__ void
         bool input_permute,
         bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -169,7 +168,7 @@ __global__ void
     ignore = G1;
     ignore = input_permute;
     ignore = output_permute;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }
 // Computes C = A * B0 * B1
...
@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp

...
@@ -60,8 +60,7 @@ __global__ void
         bool input_permute,
         bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -168,7 +167,7 @@ __global__ void
     ignore = G1;
     ignore = input_permute;
     ignore = output_permute;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }
 // Computes C = A * B0 * B1
...
@@ -595,8 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
@@ -952,8 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
test/grouped_convnd_bwd_data/CMakeLists.txt

 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
 if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
...
test/grouped_convnd_bwd_weight/CMakeLists.txt

 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
...