Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
06701e70
Unverified
Commit
06701e70
authored
Jul 09, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Jul 09, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
5800d24e
da42a889
Changes
156
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
689 additions
and
31 deletions
+689
-31
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+1
-1
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
..._bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
+2
-2
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
...d_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
+12
-1
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp
...d_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp
+12
-1
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
...m_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
+12
-1
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
...emm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
+12
-1
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
..._scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
+15
-4
example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp
...oftmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp
+12
-1
example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp
..._softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp
+12
-1
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
...m_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
+15
-4
example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp
...v_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp
+12
-1
example/CMakeLists.txt
example/CMakeLists.txt
+3
-3
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+3
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+3
-0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+3
-0
include/ck/ck.hpp
include/ck/ck.hpp
+5
-9
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+5
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+499
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+7
-0
include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp
...ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp
+44
-0
No files found.
example/04_gemm_add_add_fastgelu/CMakeLists.txt
View file @
06701e70
...
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
View file @
06701e70
...
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
...
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
2
,
2
,
4
,
4
,
4
,
4
,
tru
e
,
fals
e
,
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
4
,
4
,
4
,
tru
e
,
fals
e
,
1
,
1
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
S
<
1
,
64
,
1
,
2
>
,
...
...
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp
View file @
06701e70
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
#include "common_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
// kernel data types
// kernel data types
using
InKernelDataType
=
FP16
;
using
InKernelDataType
=
FP16
;
...
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
...
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp
View file @
06701e70
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
#include "common_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
// kernel data types
// kernel data types
using
InKernelDataType
=
I8
;
using
InKernelDataType
=
I8
;
...
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
...
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
06701e70
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
06701e70
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp
View file @
06701e70
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
...
@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
//
#define CK_MHA_USE_WAVE_8
using
DeviceMHAFactory
=
using
DeviceMHAFactory
=
std
::
tuple
<
std
::
tuple
<
#ifdef CK_MHA_USE_WAVE_1
#ifdef CK_MHA_USE_WAVE_1
...
@@ -277,10 +278,10 @@ using DeviceMHAFactory =
...
@@ -277,10 +278,10 @@ using DeviceMHAFactory =
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
MaskingSpec
>
#endif
#endif
#ifdef CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_WAVE_8
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
...
@@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_cross_attention_wmma.inc"
#include "run_cross_attention_wmma.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp
View file @
06701e70
...
@@ -28,6 +28,7 @@ Example is GQA-4
...
@@ -28,6 +28,7 @@ Example is GQA-4
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -299,4 +300,14 @@ using ReferenceGemm1Instance =
...
@@ -299,4 +300,14 @@ using ReferenceGemm1Instance =
#include "run_grouped_query_attention_forward_wmma.inc"
#include "run_grouped_query_attention_forward_wmma.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp
View file @
06701e70
...
@@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.”
...
@@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.”
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_
...
@@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_
#include "run_multi_query_attention_forward_wmma.inc"
#include "run_multi_query_attention_forward_wmma.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run
(
argc
,
argv
);
}
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
View file @
06701e70
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
...
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/host_utility/device_prop.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
...
@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
//
#define CK_MHA_USE_WAVE_8
using
DeviceMHAFactory
=
using
DeviceMHAFactory
=
std
::
tuple
<
std
::
tuple
<
#ifdef CK_MHA_USE_WAVE_1
#ifdef CK_MHA_USE_WAVE_1
...
@@ -277,10 +278,10 @@ using DeviceMHAFactory =
...
@@ -277,10 +278,10 @@ using DeviceMHAFactory =
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
MaskingSpec
>
#endif
#endif
#ifdef CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_WAVE_8
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
...
@@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
...
@@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_self_attention_wmma.inc"
#include "run_self_attention_wmma.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run
(
argc
,
argv
);
}
example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp
View file @
06701e70
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
#include "common.hpp"
#include "common.hpp"
#include "ck/host_utility/device_prop.hpp"
using
OutDataType
=
FP16
;
using
OutDataType
=
FP16
;
using
WeiDataType
=
FP16
;
using
WeiDataType
=
FP16
;
...
@@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
...
@@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
#include "run_grouped_conv_bwd_data_example.inc"
#include "run_grouped_conv_bwd_data_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_bwd_data_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
is_supported
=
ck
::
is_gfx11_supported
();
if
(
!
is_supported
)
{
std
::
cout
<<
"WARNING: wmma example not supported on the platform "
<<
ck
::
get_device_name
()
<<
std
::
endl
;
return
0
;
}
return
run_grouped_conv_bwd_data_example
(
argc
,
argv
);
}
example/CMakeLists.txt
View file @
06701e70
...
@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
...
@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach
()
endforeach
()
#Do not build any WMMA examples if gfx11 targets are not on the list
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT EX_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"_wmma"
)
if
(
NOT EX_TARGETS MATCHES
"gfx11"
AND
NOT EX_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma example
${
source
}
"
)
message
(
"removing wmma example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
...
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
...
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach
()
endforeach
()
#Do not build any WMMA examples if gfx11 targets are not on the list
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT EX_TARGETS MATCHES
"gfx1
1
"
AND source MATCHES
"_wmma"
)
if
(
NOT EX_TARGETS MATCHES
"gfx11"
AND
NOT EX_TARGETS MATCHES
"gfx1
2
"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma example
${
source
}
"
)
message
(
"removing wmma example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
...
@@ -181,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
...
@@ -181,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir
# add all example subdir
file
(
GLOB dir_list LIST_DIRECTORIES true *
)
file
(
GLOB dir_list LIST_DIRECTORIES true *
)
FOREACH
(
subdir
${
dir_list
}
)
FOREACH
(
subdir
${
dir_list
}
)
IF
(
IS_DIRECTORY
"
${
subdir
}
"
)
if
(
IS_DIRECTORY
"
${
subdir
}
"
AND EXISTS
"
${
subdir
}
/CMakeLists.txt"
)
add_subdirectory
(
${
subdir
}
)
add_subdirectory
(
${
subdir
}
)
ENDIF
()
ENDIF
()
ENDFOREACH
()
ENDFOREACH
()
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
06701e70
...
@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
...
@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
per_hdim_case
=
per_hdim_case
+
FMHA_BWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
per_hdim_case
=
per_hdim_case
+
FMHA_BWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
if_i
=
'if'
if
i
==
0
else
'else if'
if_i
=
'if'
if
i
==
0
else
'else if'
per_dtypes
=
per_dtypes
+
FMHA_BWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
per_dtypes
=
per_dtypes
+
FMHA_BWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
if
not
per_dtypes
:
# empty string we add some ignore to suppress warning in api
per_dtypes
+=
' (void)t ; (void)s ; (void)a;'
return
FMHA_BWD_KERNEL_HEADER
+
FMHA_BWD_API
.
format
(
F_dispatch
=
per_dtypes
)
return
FMHA_BWD_KERNEL_HEADER
+
FMHA_BWD_API
.
format
(
F_dispatch
=
per_dtypes
)
# GEMM0: Q@K=S^T
# GEMM0: Q@K=S^T
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
View file @
06701e70
...
@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
...
@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
if_i
=
'if'
if
i
==
0
else
'else if'
if_i
=
'if'
if
i
==
0
else
'else if'
per_dtypes
=
per_dtypes
+
FMHA_FWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
per_dtypes
=
per_dtypes
+
FMHA_FWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
if
not
per_dtypes
:
# empty string we add some ignore to suppress warning in api
per_dtypes
+=
' (void)t ; (void)s ; (void)a;'
return
FMHA_FWD_KERNEL_HEADER
+
FMHA_FWD_API
.
format
(
F_dispatch
=
per_dtypes
)
return
FMHA_FWD_KERNEL_HEADER
+
FMHA_FWD_API
.
format
(
F_dispatch
=
per_dtypes
)
@
dataclass
@
dataclass
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
06701e70
...
@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
...
@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
if_i
=
'if'
if
i
==
0
else
'else if'
if_i
=
'if'
if
i
==
0
else
'else if'
per_dtypes
=
per_dtypes
+
FMHA_FWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
per_dtypes
=
per_dtypes
+
FMHA_FWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
if
not
per_dtypes
:
# empty string we add some ignore to suppress warning in api
per_dtypes
+=
' (void)t ; (void)s ; (void)a;'
return
FMHA_FWD_KERNEL_HEADER
+
FMHA_FWD_SPLITKV_API
.
format
(
F_dispatch
=
per_dtypes
)
return
FMHA_FWD_KERNEL_HEADER
+
FMHA_FWD_SPLITKV_API
.
format
(
F_dispatch
=
per_dtypes
)
@
dataclass
@
dataclass
...
...
include/ck/ck.hpp
View file @
06701e70
...
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#define __gfx11__
#endif
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
// buffer resource
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
...
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx103__)
#elif defined(__gfx103__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx11__)
#elif defined(__gfx11__)
|| defined(__gfx12__)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#endif
...
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx11__)
#elif defined(__gfx11__)
|| defined(__gfx12__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
...
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define CK_USE_AMD_MFMA_GFX940
#define CK_USE_AMD_MFMA_GFX940
#endif
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx11__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
#define CK_USE_AMD_BUFFER_LOAD 1
...
...
include/ck/host_utility/device_prop.hpp
View file @
06701e70
...
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
...
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
}
}
inline
bool
is_gfx12_supported
()
{
return
ck
::
get_device_name
()
==
"gfx1200"
||
ck
::
get_device_name
()
==
"gfx1201"
;
}
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
06701e70
...
@@ -13,6 +13,504 @@
...
@@ -13,6 +13,504 @@
namespace
ck
{
namespace
ck
{
#ifdef __gfx12__
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
ABlockDesc
,
typename
BBlockDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
bool
TransposeC
=
false
>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct
BlockwiseGemmWMMA
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static
constexpr
index_t
WaveSize
=
32
;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static
constexpr
index_t
A_KRow
=
2
;
static
constexpr
index_t
B_KRow
=
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
wmma_gemm
.
GetRegSizePerWmma
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
if
constexpr
(
AEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_a_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
if
constexpr
(
BEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_b_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk
();
constexpr
auto
mrepeat_mwave_mperWMMA_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperWMMA_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperWMMA_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperWMMA_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex7D
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk3D
();
return
make_tuple
(
Number
<
m0
>
{},
waveId_m
,
blk_idx
[
I0
],
Number
<
n0
>
{},
waveId_n
,
blk_idx
[
I1
],
blk_idx
[
I2
]);
}
using
Tuple6
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmWMMA
(
Tuple6
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple6
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
NAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
NAccVgprs
));
}
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
AccStride
));
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_assert
(
KPack
%
(
A_K1
*
A_KRow
)
==
0
,
""
);
static_assert
(
KPack
%
(
B_K1
*
B_KRow
)
==
0
,
""
);
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
static_for
<
0
,
KPerBlock
/
KPack
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatA
,
KPack
/
A_KRow
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
/
B_KRow
>
b_thread_vec
;
static_for
<
0
,
KPack
/
A_KRow
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
});
static_for
<
0
,
KPack
/
B_KRow
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
/
A_KRow
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
/
B_KRow
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
else
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KPerBlock
/
KPack
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
vector_type
<
FloatA
,
KPack
/
A_KRow
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
/
B_KRow
>
b_thread_vec
;
static_for
<
0
,
KPack
/
A_KRow
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
});
static_for
<
0
,
KPack
/
B_KRow
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
/
A_KRow
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
/
B_KRow
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
}
protected:
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
>
{},
Number
<
KPack
/
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
>
{},
Number
<
KPack
/
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
template
<
bool
EnableLds
>
struct
AThreadCopySelector
;
template
<
>
struct
AThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
};
template
<
>
struct
AThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
false
>
;
};
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
template
<
>
struct
BThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
};
template
<
>
struct
BThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
false
>
;
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
#else
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
...
@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
...
@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
};
#endif
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
06701e70
...
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
// sync point.
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
{
{
#ifdef __gfx12__
asm
volatile
(
"\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"s_barrier"
::
);
asm
volatile
(
"s_barrier"
::
);
#endif
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
...
...
include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp
0 → 100644
View file @
06701e70
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemm_Streamk_V2
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
Streamk_sel
,
ck
::
index_t
Grid_size
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment