gaoqiong / composable_kernel_ROCM · Commits

Commit ccaea50e
authored Mar 08, 2024 by Jing Zhang

    merge navi31_rel

Parents: 0b914465, 10127959
Changes: 126 files; showing 20 changed files with 614 additions and 294 deletions (+614 −294)
Changed files:

  example/01_gemm/CMakeLists.txt                                                                    +15  −2
  example/01_gemm/common.hpp                                                                         +1  −1
  example/01_gemm/gemm_xdl_fp8.cpp                                                                   +9  −5
  example/01_gemm/gemm_xdl_fp8_bf8.cpp                                                               +4  −4
  example/01_gemm/run_gemm_example.inc                                                               +4  −3
  example/29_batched_gemm_bias_e_permute/CMakeLists.txt                                              +1  −1
  example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc         +2  −3
  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt                                          +1  −1
  example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp                  +22  −0
  example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc    +2  −2
  example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc                          +112 −72
  example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc            +2  −2
  example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc              +2  −2
  example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc                           +102 −69
  example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp                   +84 −40
  example/64_fpAintB_gemm/CMakeLists.txt                                                             +1  −1
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  +8 −13
  include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp                           +1  −2
  include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp              +237 −64
  include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp        +4  −7
example/01_gemm/CMakeLists.txt

@@ -72,5 +72,18 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()
-add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
+add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
+add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
+
+list(APPEND gpu_list gfx940 gfx941 gfx942)
+set(target 0)
+foreach(gpu IN LISTS GPU_TARGETS)
+    if(gpu IN_LIST gpu_list AND target EQUAL 0)
+        add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
+        add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
+        set(target 1)
+    endif()
+endforeach()
example/01_gemm/common.hpp

@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
 struct ExecutionConfig final
 {
     bool do_verification = true;
-    int init_method      = 1;
+    int init_method      = 2;
     bool time_kernel     = false;
 };
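Note on the default above: in these examples init_method selects how inputs are generated; the examples' own usage text defines 0 = no init, 1 = integer value, 2 = decimal value, so the default moves from integer to decimal initialization (see the matching run_gemm_example.inc change below). A minimal, self-contained sketch of the two modes, using a plain std::vector<float> stand-in rather than CK's Tensor:

    #include <random>
    #include <vector>

    // init_method semantics taken from the examples' help text:
    // 1 = small random integers (products stay exact even in low precision),
    // 2 = small random decimals (exercises rounding; needs tolerant verification).
    void fill_input(std::vector<float>& buf, int init_method)
    {
        std::mt19937 gen{11939};
        if(init_method == 1)
        {
            std::uniform_int_distribution<int> dist{-2, 2};
            for(auto& x : buf)
                x = static_cast<float>(dist(gen));
        }
        else if(init_method == 2)
        {
            std::uniform_real_distribution<float> dist{-0.1f, 0.1f};
            for(auto& x : buf)
                x = dist(gen);
        }
        // init_method == 0: leave the buffer as-is.
    }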
example/01_gemm/gemm_xdl_fp8.cpp

@@ -20,14 +20,18 @@ using BElementOp = PassThrough;
 using CElementOp = PassThrough;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 
+static constexpr auto LoopSched   = ck::make_default_loop_scheduler();
+static constexpr auto PipelineVer = ck::PipelineVersion::v1;
+
+using ComputeTypeA = ck::f8_t;
+using ComputeTypeB = ck::f8_t;
+
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-    // [four-row column-header comment: ALayout ... CBlockTransfer ScalarPerVector]
-    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
-      AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2,
-      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
-      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
-      1, 1, S<1, 64, 1, 4>, 8>;
+    // [same column-header comment, extended with: Loop Scheduler | Pipeline Version | Compute TypeA | Compute TypeB]
+    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
+      AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2,
+      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
+      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
+      1, 1, S<1, 64, 1, 4>, 8,
+      LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
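The new ComputeTypeA/ComputeTypeB parameters decouple the storage types of A and B from the type the inner product is computed in (both ck::f8_t here). A minimal host-side sketch of that semantics, with generic template stand-ins rather than CK's reference kernel:

    #include <cstddef>
    #include <vector>

    // Generic stand-in types: with ComputeA/ComputeB = float this runs as-is; the
    // point is that each operand is converted to its compute type before the
    // multiply, which is where f8 rounding error would enter even for fp16 inputs.
    template <typename AType, typename BType, typename CType, typename AccType,
              typename ComputeA, typename ComputeB>
    void reference_gemm(const std::vector<AType>& a, // M x K, row-major
                        const std::vector<BType>& b, // K x N, row-major
                        std::vector<CType>& c,       // M x N, row-major
                        std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                AccType acc{0};
                for(std::size_t k = 0; k < K; ++k)
                {
                    const auto a_compute = static_cast<ComputeA>(a[m * K + k]);
                    const auto b_compute = static_cast<ComputeB>(b[k * N + n]);
                    acc += static_cast<AccType>(a_compute) * static_cast<AccType>(b_compute);
                }
                c[m * N + n] = static_cast<CType>(acc);
            }
    }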
example/01_gemm/gemm_xdl_fp8_bf8.cpp

@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t;
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-    // [column-header comment rows without the trailing Loop/Pipeline/Compute columns]
+    // [column-header comment rows, extended with: Loop Scheduler | Pipeline Version | Compute TypeA | Compute TypeB]
     < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
       AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2,
       S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
       S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1,
       1, 1, S<1, 64, 1, 4>, 8,
       LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
 // clang-format on
example/01_gemm/run_gemm_example.inc

@@ -85,8 +85,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
         break;
     default:
-        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
-        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+        ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
+        ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
     }
 
     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

@@ -256,7 +256,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 #else
     c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
 
-    return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
+    return ck::utils::check_err(
+        c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1);
 #endif
 }
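Verification now passes an explicit failure message plus loose relative and absolute tolerances of 1e-1, consistent with fp8 rounding and the shrunken ±0.1 init range above. A sketch of the combined tolerance criterion such thresholds typically plug into (hypothetical helper, not CK's check_err signature):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Accept out[i] when |out[i] - ref[i]| <= atol + rtol * |ref[i]| for all i.
    bool check_close(const std::vector<float>& out,
                     const std::vector<float>& ref,
                     float rtol,
                     float atol)
    {
        if(out.size() != ref.size())
            return false;
        for(std::size_t i = 0; i < out.size(); ++i)
            if(std::abs(out[i] - ref[i]) > atol + rtol * std::abs(ref[i]))
                return false;
        return true;
    }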
example/29_batched_gemm_bias_e_permute/CMakeLists.txt
View file @
ccaea50e
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx11
00"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102
"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp
)
endif
()
endif
()
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc

@@ -279,9 +279,8 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
     switch(conv_param.num_dim_spatial_)
     {
     // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
     case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
     // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
     }
 
     return false;
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
ccaea50e
if
(
GPU_TARGETS MATCHES
"gfx11
00"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102
"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_example_executable
(
example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp
)
add_example_executable
(
example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp
)
...
...
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp

@@ -301,6 +301,28 @@ using DeviceMHAFactory =
         S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 128, 1, 2>, 8,
+        MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType,
+        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 256,
+        // Gemm 0
+        128, 64, 48, 8, 4,
+        // Gemm 1
+        48, 64, 8,
+        16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 4, 3,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 128, 1, 2>, 8,
         MaskingSpec>
 #endif
     >;
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc

@@ -182,9 +182,9 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
 
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm    = DeviceMHAInstance{};
         auto invoker = gemm.MakeInvoker();
         auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
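For context, the loop above instantiates and runs every kernel in the DeviceMHAFactory tuple at compile time; the rename only drops a stale conv_ prefix from the loop variable. A self-contained sketch of the same visit-each-tuple-element pattern (index_sequence based; ck::static_for itself is CK's utility, this just shows the idea):

    #include <cstdio>
    #include <tuple>
    #include <utility>

    template <typename Tuple, typename F, std::size_t... Is>
    void for_each_instance_impl(F&& f, std::index_sequence<Is...>)
    {
        // Default-construct each instance type in the tuple and hand it to f.
        (f(std::tuple_element_t<Is, Tuple>{}), ...);
    }

    template <typename Tuple, typename F>
    void for_each_instance(F&& f)
    {
        for_each_instance_impl<Tuple>(std::forward<F>(f),
                                      std::make_index_sequence<std::tuple_size_v<Tuple>>{});
    }

    struct InstanceA { const char* name() const { return "InstanceA"; } };
    struct InstanceB { const char* name() const { return "InstanceB"; } };
    using Factory = std::tuple<InstanceA, InstanceB>;

    int main()
    {
        for_each_instance<Factory>([](auto instance) { std::printf("%s\n", instance.name()); });
    }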
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc

@@ -9,20 +9,18 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 64;
-    ck::index_t K = 80;
-    ck::index_t O = 80;
-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 2;
-    ck::index_t G1 = 8;
+    ck::index_t q_sequence_length  = 256;
+    ck::index_t kv_sequence_length = 64;
+    ck::index_t head_dim           = 80;
+    // Output shape C[batch_size, q_sequence_length, head_num, head_dim]. Batch dim, outer dim,
+    // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
+    // permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t batch_size = 2;
+    ck::index_t head_num   = 8;
     float alpha = 1;
-    bool input_permute  = true;
+    bool input_permute  = false;
     bool output_permute = true;
 
     if(argc == 1)

@@ -35,58 +33,85 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
-        M  = std::stoi(argv[4]);
-        N  = std::stoi(argv[5]);
-        K  = std::stoi(argv[6]);
-        O  = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-        alpha          = std::stof(argv[10]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        q_sequence_length  = std::stoi(argv[4]);
+        kv_sequence_length = std::stoi(argv[5]);
+        head_dim           = std::stoi(argv[6]);
+        batch_size         = std::stoi(argv[7]);
+        head_num           = std::stoi(argv[8]);
+        alpha              = std::stof(argv[9]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf(
+            "arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n");
+        printf("arg9: scale (alpha)\n");
         exit(0);
     }
 
-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
-    std::vector<ck::index_t> a_gs_ms_ks_strides =
-        input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
-
-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
-    std::vector<ck::index_t> b0_gs_ns_ks_strides =
-        input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
-
-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
-    std::vector<ck::index_t> b1_gs_os_ns_strides =
-        input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
-
-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
-    std::vector<ck::index_t> c_gs_ms_os_strides =
-        output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim};
+    std::vector<ck::index_t> a_gs_ms_ks_strides =
+        input_permute
+            ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // A layout [batch_size, q_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * q_sequence_length * head_dim,
+                                       q_sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // A layout [batch_size, head_num, q_sequence_length, head_dim]
+
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, kv_sequence_length, head_dim};
+    std::vector<ck::index_t> b0_gs_ns_ks_strides =
+        input_permute
+            ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // B0 layout [batch_size, kv_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * kv_sequence_length * head_dim,
+                                       kv_sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim]
+
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, kv_sequence_length};
+    std::vector<ck::index_t> b1_gs_os_ns_strides =
+        input_permute
+            ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       1,
+                                       head_num * head_dim} // B1 layout [batch_size, kv_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * kv_sequence_length * head_dim,
+                                       kv_sequence_length * head_dim,
+                                       1,
+                                       head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim]
+
+    std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim};
+    std::vector<ck::index_t> c_gs_ms_os_strides =
+        output_permute
+            ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // C layout [batch_size, q_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * q_sequence_length * head_dim,
+                                       q_sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // C layout [batch_size, head_num, q_sequence_length, head_dim]
 
     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);

@@ -158,9 +183,14 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
 
-    std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K};
-    std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
-        N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, M, G1, 2, K]
+    std::vector<ck::index_t> kv_gs_ns_ks_lengths{
+        batch_size, head_num, kv_sequence_length, 2, head_dim};
+    std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
+        kv_sequence_length * head_num * 2 * head_dim,
+        2 * head_dim,
+        head_num * 2 * head_dim,
+        head_dim,
+        1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim]
 
     Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
     // merge kv into a packed pointer send to device
     b0_gs_ns_ks.ForEach(

@@ -189,20 +219,20 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
 
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm    = DeviceMHAInstance{};
         auto invoker = gemm.MakeCrossAttnInvoker();
         auto argument =
             gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
                                        static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                       G0,
-                                       M,
-                                       N,
-                                       G1,
-                                       K,
+                                       batch_size,
+                                       q_sequence_length,
+                                       kv_sequence_length,
+                                       head_num,
+                                       head_dim,
                                        alpha);
 
         // if(!gemm.IsSupportedArgument(argument))

@@ -212,13 +242,17 @@ int run(int argc, char* argv[])
         //     return 0;
         // }
 
-        ck::index_t BatchCount = G0 * G1;
+        ck::index_t BatchCount = batch_size * head_num;
 
         float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
 
-        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
-                                BatchCount;
+        std::size_t flop =
+            (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 +
+             size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) *
+            BatchCount;
+        std::size_t num_btype =
+            (sizeof(ADataType) * q_sequence_length * head_dim +
+             sizeof(B0DataType) * head_dim * kv_sequence_length +
+             sizeof(B1DataType) * kv_sequence_length * head_dim +
+             sizeof(CDataType) * q_sequence_length * head_dim) *
+            BatchCount;
 
         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

@@ -237,22 +271,26 @@ int run(int argc, char* argv[])
         {
             c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
 
-            Tensor<ADataType> a_g_m_k({BatchCount, M, K});
-            Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
-            Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-            Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
-            Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
-            Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+            Tensor<ADataType> a_g_m_k({BatchCount, q_sequence_length, head_dim});
+            Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, kv_sequence_length});
+            Tensor<B1DataType> b1_g_n_o({BatchCount, kv_sequence_length, head_dim});
+            Tensor<Acc0DataType> acc0_g_m_n(
+                {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0
+            Tensor<ADataType> a1_g_m_n(
+                {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after softmax
+            Tensor<CDataType> c_g_m_o_host_result(
+                {BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1
 
             // permute
             a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-                a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+                a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
             });
             b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-                b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
            });
             b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-                b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
 
             // gemm 0

@@ -264,7 +302,7 @@ int run(int argc, char* argv[])
             ref_gemm0_invoker.Run(ref_gemm0_argument);
 
             // masking
-            const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
+            const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length);
             acc0_g_m_n.ForEach([&](auto& self, auto idx) {
                 if(mask.IsMaskedElement(idx[1], idx[2]))
                     self(idx) = -ck::NumericLimits<float>::Infinity();

@@ -294,7 +332,7 @@ int run(int argc, char* argv[])
                 const size_t& g0 = idx[0];
                 const size_t& g1 = idx[1];
-                const size_t g   = g0 * G1 + g1;
+                const size_t g   = g0 * head_num + g1;
 
                 self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
             });

@@ -330,8 +368,10 @@ int run(int argc, char* argv[])
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
-    std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
-              << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
+    std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
+              << ", q_sequence_length: " << q_sequence_length
+              << ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim
+              << std::endl;
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
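The paired stride vectors above describe the same logical (batch, head, row, column) tensor over two physical memory orders; only the strides change, never the data. A minimal, self-contained sketch of the index-to-offset arithmetic (hypothetical helper, values taken from the defaults above):

    #include <array>
    #include <cstddef>

    // Offset of logical element (g0, g1, m, k) under a given stride vector.
    std::size_t offset(const std::array<std::size_t, 4>& idx,
                       const std::array<std::size_t, 4>& strides)
    {
        std::size_t off = 0;
        for(std::size_t d = 0; d < 4; ++d)
            off += idx[d] * strides[d];
        return off;
    }

    int main()
    {
        const std::size_t head_num = 8, q_sequence_length = 256, head_dim = 80;

        // input_permute == true : memory order [batch, q_sequence_length, head_num, head_dim]
        const std::array<std::size_t, 4> permuted{
            q_sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1};

        // input_permute == false: memory order [batch, head_num, q_sequence_length, head_dim]
        const std::array<std::size_t, 4> packed{
            head_num * q_sequence_length * head_dim, q_sequence_length * head_dim, head_dim, 1};

        // Same logical A element (g0=1, g1=3, m=5, k=7), two different flat offsets.
        const std::size_t a = offset({1, 3, 5, 7}, permuted);
        const std::size_t b = offset({1, 3, 5, 7}, packed);
        return a == b ? 1 : 0; // the offsets differ, so this returns 0
    }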
example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc

@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
 
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm    = DeviceMHAInstance{};
         auto invoker = gemm.MakeInvoker();
         auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc

@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
 
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm    = DeviceMHAInstance{};
         auto invoker = gemm.MakeInvoker();
         auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc

@@ -9,20 +9,17 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 256;
-    ck::index_t K = 80;
-    ck::index_t O = 80;
-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 2;
-    ck::index_t G1 = 8;
+    ck::index_t sequence_length = 256;
+    ck::index_t head_dim        = 80;
+    // Output shape C[batch_size, sequence_length, head_num, head_dim]. Batch dim, outer dim, inner
+    // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
+    // permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t batch_size = 2;
+    ck::index_t head_num   = 8;
     float alpha = 1;
-    bool input_permute  = true;
+    bool input_permute  = false;
     bool output_permute = true;
 
     if(argc == 1)

@@ -35,58 +32,81 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 9)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);
-        M  = std::stoi(argv[4]);
-        N  = std::stoi(argv[5]);
-        K  = std::stoi(argv[6]);
-        O  = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-        alpha          = std::stof(argv[10]);
-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        sequence_length = std::stoi(argv[4]);
+        head_dim        = std::stoi(argv[5]);
+        batch_size      = std::stoi(argv[6]);
+        head_num        = std::stoi(argv[7]);
+        alpha           = std::stof(argv[8]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n");
+        printf("arg8: scale (alpha)\n");
         exit(0);
     }
 
-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
-    std::vector<ck::index_t> a_gs_ms_ks_strides =
-        input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
-
-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
-    std::vector<ck::index_t> b0_gs_ns_ks_strides =
-        input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
-
-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
-    std::vector<ck::index_t> b1_gs_os_ns_strides =
-        input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
-
-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
-    std::vector<ck::index_t> c_gs_ms_os_strides =
-        output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim};
+    std::vector<ck::index_t> a_gs_ms_ks_strides =
+        input_permute
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // A layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // A layout [batch_size, head_num, sequence_length, head_dim]
+
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim};
+    std::vector<ck::index_t> b0_gs_ns_ks_strides =
+        input_permute
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // B0 layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // B0 layout [batch_size, head_num, sequence_length, head_dim]
+
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length};
+    std::vector<ck::index_t> b1_gs_os_ns_strides =
+        input_permute
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       1,
+                                       head_num * head_dim} // B1 layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       1,
+                                       head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim]
+
+    std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim};
+    std::vector<ck::index_t> c_gs_ms_os_strides =
+        output_permute
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
+                                       head_dim,
+                                       head_num * head_dim,
+                                       1} // C layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim,
+                                       sequence_length * head_dim,
+                                       head_dim,
+                                       1}; // C layout [batch_size, head_num, sequence_length, head_dim]
 
     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);

@@ -158,9 +178,14 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
 
-    std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K};
-    std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
-        M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K]
+    std::vector<ck::index_t> qkv_gs_ms_ks_lengths{
+        batch_size, head_num, sequence_length, 3, head_dim};
+    std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
+        sequence_length * head_num * 3 * head_dim,
+        3 * head_dim,
+        head_num * 3 * head_dim,
+        head_dim,
+        1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim]
 
     Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
     // merge qkv into a packed pointer send to device
     a_gs_ms_ks.ForEach(

@@ -190,18 +215,18 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
 
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm    = DeviceMHAInstance{};
         auto invoker = gemm.MakeSelfAttnInvoker();
         auto argument =
             gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      G0,
-                                      M,
-                                      G1,
-                                      K,
+                                      batch_size,
+                                      sequence_length,
+                                      head_num,
+                                      head_dim,
                                       alpha);
 
         // if(!gemm.IsSupportedArgument(argument))

@@ -211,13 +236,17 @@ int run(int argc, char* argv[])
         //     return 0;
         // }
 
-        ck::index_t BatchCount = G0 * G1;
+        ck::index_t BatchCount = batch_size * head_num;
 
         float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
 
-        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
-                                BatchCount;
+        std::size_t flop =
+            (size_t(sequence_length) * sequence_length * head_dim * 2 +
+             size_t(sequence_length) * sequence_length * head_dim * 2) *
+            BatchCount;
+        std::size_t num_btype =
+            (sizeof(ADataType) * sequence_length * head_dim +
+             sizeof(B0DataType) * head_dim * sequence_length +
+             sizeof(B1DataType) * sequence_length * head_dim +
+             sizeof(CDataType) * sequence_length * head_dim) *
+            BatchCount;
 
         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

@@ -236,22 +265,25 @@ int run(int argc, char* argv[])
         {
             c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
 
-            Tensor<ADataType> a_g_m_k({BatchCount, M, K});
-            Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
-            Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-            Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
-            Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
-            Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+            Tensor<ADataType> a_g_m_k({BatchCount, sequence_length, head_dim});
+            Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, sequence_length});
+            Tensor<B1DataType> b1_g_n_o({BatchCount, sequence_length, head_dim});
+            Tensor<Acc0DataType> acc0_g_m_n(
+                {BatchCount, sequence_length, sequence_length}); // scratch object after gemm0
+            Tensor<ADataType> a1_g_m_n(
+                {BatchCount, sequence_length, sequence_length}); // scratch object after softmax
+            Tensor<CDataType> c_g_m_o_host_result(
+                {BatchCount, sequence_length, head_dim}); // scratch object after gemm1
 
             // permute
             a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-                a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+                a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
             });
             b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-                b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
             b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-                b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+                b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
             });
 
             // gemm 0

@@ -263,7 +295,7 @@ int run(int argc, char* argv[])
             ref_gemm0_invoker.Run(ref_gemm0_argument);
 
             // masking
-            const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
+            const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length);
             acc0_g_m_n.ForEach([&](auto& self, auto idx) {
                 if(mask.IsMaskedElement(idx[1], idx[2]))
                     self(idx) = -ck::NumericLimits<float>::Infinity();

@@ -293,7 +325,7 @@ int run(int argc, char* argv[])
                 const size_t& g0 = idx[0];
                 const size_t& g1 = idx[1];
-                const size_t g   = g0 * G1 + g1;
+                const size_t g   = g0 * head_num + g1;
 
                 self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
             });

@@ -329,8 +361,9 @@ int run(int argc, char* argv[])
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
-    std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
-              << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
+    std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
+              << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim
+              << std::endl;
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
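The packed QKV layout above keeps Q, K and V for one (token, head) adjacent via the extra length-3 dimension, so a single device pointer serves all three GEMM operands. A sketch of the flat-offset arithmetic implied by those strides (hypothetical helper):

    #include <cstddef>

    // Flat offset for index (g0, g1, m, t, k) in the packed QKV tensor with
    // lengths {batch_size, head_num, sequence_length, 3, head_dim} and strides
    // {seq*heads*3*dim, 3*dim, heads*3*dim, dim, 1} from the diff, i.e. memory
    // order [batch, token, head, {q,k,v}, channel].
    std::size_t qkv_offset(std::size_t g0, // batch index
                           std::size_t g1, // head index
                           std::size_t m,  // token position
                           std::size_t t,  // 0 = Q, 1 = K, 2 = V
                           std::size_t k,  // channel within head_dim
                           std::size_t sequence_length,
                           std::size_t head_num,
                           std::size_t head_dim)
    {
        return g0 * (sequence_length * head_num * 3 * head_dim) //
             + g1 * (3 * head_dim)                              //
             + m * (head_num * 3 * head_dim)                    //
             + t * head_dim                                     //
             + k;
    }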
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp
View file @
ccaea50e
...
@@ -83,12 +83,34 @@ using DeviceMHAFactory =
...
@@ -83,12 +83,34 @@ using DeviceMHAFactory =
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
32
,
// Gemm 0
// Gemm 0
16
,
128
,
64
,
8
,
8
,
16
,
32
,
160
,
8
,
8
,
// Gemm 1
// Gemm 1
64
,
64
,
8
,
80
,
32
,
8
,
16
,
16
,
16
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
2
,
5
,
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
2
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
1
,
false
,
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
16
,
1
,
2
>
,
8
,
MaskingSpec
>
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc0DataType
,
Acc1BiasDataType
,
Acc1DataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
// Gemm 0
16
,
64
,
80
,
8
,
8
,
// Gemm 1
80
,
64
,
8
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
8
,
4
,
1
,
4
,
5
,
// ABlockTransfer MK -> K0 M K1
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
// B0BlockTransfer LK -> K0 L K1
...
@@ -105,12 +127,12 @@ using DeviceMHAFactory =
...
@@ -105,12 +127,12 @@ using DeviceMHAFactory =
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
32
,
32
,
// Gemm 0
// Gemm 0
16
,
64
,
6
4
,
8
,
8
,
16
,
64
,
4
8
,
8
,
8
,
// Gemm 1
// Gemm 1
6
4
,
64
,
8
,
4
8
,
64
,
8
,
16
,
16
,
16
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
4
,
4
,
1
,
4
,
3
,
// ABlockTransfer MK -> K0 M K1
// ABlockTransfer MK -> K0 M K1
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
// B0BlockTransfer LK -> K0 L K1
...
@@ -129,16 +151,16 @@ using DeviceMHAFactory =
...
@@ -129,16 +151,16 @@ using DeviceMHAFactory =
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
64
,
64
,
// Gemm 0
// Gemm 0
32
,
128
,
64
,
8
,
8
,
32
,
64
,
48
,
8
,
8
,
// Gemm 1
// Gemm 1
6
4
,
64
,
8
,
4
8
,
64
,
8
,
16
,
16
,
16
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
8
,
4
,
1
,
4
,
3
,
// ABlockTransfer MK -> K0 M K1
// ABlockTransfer MK -> K0 M K1
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
// B0BlockTransfer LK -> K0 L K1
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
S
<
2
,
4
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
false
,
// CShuffleBlockTransfer MN
// CShuffleBlockTransfer MN
...
@@ -151,16 +173,38 @@ using DeviceMHAFactory =
...
@@ -151,16 +173,38 @@ using DeviceMHAFactory =
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1, 64,
        // Gemm 0
-       32, 64, 64, 8, 8,
+       32, 64, 80, 8, 8,
        // Gemm 1
-       64, 64, 8, 16, 16, 16,
+       80, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
-       1, 4, 4,
+       1, 4, 5,
        // ABlockTransfer MK -> K0 M K1
        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
-       S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+       S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 32, 1, 2>, 8,
        MaskingSpec>,
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1, 64,
        // Gemm 0
        32, 32, 160, 8, 8,
        // Gemm 1
        80, 32, 8,
        16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
        1, 2, 5,
        // ABlockTransfer MK -> K0 M K1
        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
        // CShuffleBlockTransfer MN
...
@@ -175,20 +219,20 @@ using DeviceMHAFactory =
...
@@ -175,20 +219,20 @@ using DeviceMHAFactory =
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
GemmSpec
,
TensorSpecA
,
TensorSpecB0
,
TensorSpecB1
,
TensorSpecC
,
1
,
128
,
128
,
// Gemm 0
// Gemm 0
64
,
128
,
64
,
8
,
8
,
64
,
128
,
80
,
8
,
8
,
// Gemm 1
// Gemm 1
64
,
64
,
8
,
80
,
64
,
8
,
16
,
16
,
16
,
16
,
16
,
16
,
// Per repeat = wave_m = wave_num, wave_n = 1
// Per repeat = wave_m = wave_num, wave_n = 1
1
,
8
,
4
,
1
,
8
,
5
,
// ABlockTransfer MK -> K0 M K1
// ABlockTransfer MK -> K0 M K1
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B0BlockTransfer LK -> K0 L K1
// B0BlockTransfer LK -> K0 L K1
S
<
8
,
1
6
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
6
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
// B1BlockTransfer NL -> L0 N L1
// B1BlockTransfer NL -> L0 N L1
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
S
<
2
,
8
,
8
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
1
,
false
,
// CShuffleBlockTransfer MN
// CShuffleBlockTransfer MN
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
1
,
1
,
S
<
1
,
64
,
1
,
2
>
,
8
,
MaskingSpec
>
,
MaskingSpec
>
,
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
...
@@ -197,45 +241,45 @@ using DeviceMHAFactory =
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1, 128,
        // Gemm 0
-       64, 64, 64, 8, 8,
+       64, 192, 48, 8, 8,
        // Gemm 1
-       64, 64, 8, 16, 16, 16,
+       48, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
-       1, 4, 4,
+       1, 12, 3,
        // ABlockTransfer MK -> K0 M K1
        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
-       S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+       S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 64, 1, 2>, 8,
        MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
-       1, 256,
+       1, 128,
        // Gemm 0
-       128, 128, 64, 8, 8,
+       64, 64, 48, 8, 8,
        // Gemm 1
-       64, 64, 8, 16, 16, 16,
+       48, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
-       1, 8, 4,
+       1, 4, 3,
        // ABlockTransfer MK -> K0 M K1
-       S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+       S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
-       S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+       S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B1BlockTransfer NL -> L0 N L1
-       S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
+       S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
        // CShuffleBlockTransfer MN
-       1, 1, S<1, 128, 1, 2>, 8,
+       1, 1, S<1, 64, 1, 2>, 8,
        MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
...
@@ -243,18 +287,18 @@ using DeviceMHAFactory =
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        1, 256,
        // Gemm 0
-       128, 128, 64, 8, 8,
+       128, 192, 48, 8, 4,
        // Gemm 1
-       64, 64, 8, 16, 16, 16,
+       48, 64, 8, 16, 16, 16,
        // Per repeat = wave_m = wave_num, wave_n = 1
-       1, 8, 4,
+       1, 12, 3,
        // ABlockTransfer MK -> K0 M K1
        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        // B0BlockTransfer LK -> K0 L K1
-       S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+       S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
        // B1BlockTransfer NL -> L0 N L1
        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
        // CShuffleBlockTransfer MN
        1, 1, S<1, 128, 1, 2>, 8,
        MaskingSpec>
...
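The factory above only enumerates candidate instances; selection happens in the caller, which typically walks the tuple and launches the first instance whose argument passes the support check. A minimal sketch of that loop follows; run_first_supported and make_arg are hypothetical helpers introduced here, while StreamConfig, MakeInvoker and IsSupportedArgument are the usual CK entry points.

    #include <tuple>

    // Hypothetical dispatch helper: try each device-op type in a factory tuple
    // (e.g. the DeviceMHAFactory above) and run the first supported instance.
    template <class FactoryTuple, class MakeArg>
    bool run_first_supported(MakeArg&& make_arg)
    {
        bool launched = false;
        ck::static_for<0, std::tuple_size_v<FactoryTuple>, 1>{}([&](auto i) {
            if(launched)
                return; // already ran an instance; skip the rest
            using Op = std::tuple_element_t<i.value, FactoryTuple>;
            Op op;
            auto arg = make_arg(op); // caller builds the Op-specific argument
            if(op.IsSupportedArgument(arg)) // does this tile config fit the shape?
            {
                op.MakeInvoker().Run(arg, StreamConfig{nullptr, false});
                launched = true;
            }
        });
        return launched;
    }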
example/64_fpAintB_gemm/CMakeLists.txt
View file @ ccaea50e

-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
     add_custom_target(example_fpAintB_gemm_wmma)
     add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
     add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @ ccaea50e

...
@@ -56,8 +56,7 @@ __global__ void
     bool input_permute,
     bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -162,7 +161,7 @@ __global__ void
     ignore = G1;
     ignore = input_permute;
     ignore = output_permute;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }

 // Self-Attention
...
@@ -188,8 +187,7 @@ __global__ void
     index_t head_size,
     float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -294,7 +292,7 @@ __global__ void
     ignore = head_count;
     ignore = head_size;
     ignore = alpha;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }

 // Cross-Attention
...
@@ -323,8 +321,7 @@ __global__ void
     index_t head_size,
     float alpha)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -435,7 +432,7 @@ __global__ void
     ignore = head_count;
     ignore = head_size;
     ignore = alpha;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }

 // Computes C = A * B0 * B1
 // MN = MK * KL * LN
...
@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
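All of the guard edits in this header are the same move: the per-ASIC macro list collapses into the compiler's gfx11 family macro, with an empty stub kept for every other device compilation pass. The pattern in isolation, as a sketch: the kernel name and parameters below are placeholders, and the real file assigns to ck::ignore where this sketch casts to void.

    __global__ void gfx11_only_kernel_sketch(int head_count, int head_size, float alpha)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
        // real kernel body: compiled for the host pass and for any gfx11xx target,
        // which the single __gfx11__ macro now covers
    #else
        // other device targets get a stub; consume the parameters so the
        // compiler does not warn about them
        (void)head_count;
        (void)head_size;
        (void)alpha;
    #endif
    }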
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
View file @ ccaea50e

...
@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                            is_same_v<AccDataType, int32_t>))
...
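ck::is_navi3_supported() now centralizes the device check that each IsSupportedArgument previously spelled out inline. Its real implementation lives elsewhere in CK; the sketch below reproduces only the behavior visible in this diff, namely the three ASIC names the old condition compared against, and is labeled as an assumption rather than the library source.

    #include <string>

    // Sketch of the replaced inline check, not the library's implementation;
    // the actual ck::is_navi3_supported() may accept additional gfx11 variants.
    inline bool is_navi3_supported_sketch(const std::string& device_name)
    {
        return device_name == "gfx1100" || device_name == "gfx1101" ||
               device_name == "gfx1102";
    }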
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @ ccaea50e

...
@@ -498,94 +498,95 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        }
    };

    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // check vector load of A
        if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
        {
            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
        {
            // FIXME: not rigorous
            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B
        if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
        {
            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
        {
            // FIXME: not rigorous
            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of Ds
        // only support RowMajor for now
        bool all_valid = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            if constexpr(!is_same_v<DLayout, Row>)
            {
                all_valid = false;
            }
        });

        if(!all_valid)
        {
            return false;
        }

        // check vector store of E
        // only support RowMajor for now
        if constexpr(is_same_v<ELayout, Row>)
        {
            if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
               GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                           arg.b_grid_desc_n_k_,
                                           arg.ds_grid_desc_m_n_,
                                           arg.e_grid_desc_m_n_,
...
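Splitting the check this way has a practical payoff: IsSupported is constexpr over the raw extents, so a shape known at compile time can be rejected before any argument object exists, while IsSupportedArgument keeps the runtime path (device probe plus descriptor validation). A sketch, assuming some concrete instantiation of this device op is in scope:

    // Hypothetical compile-time gate over a concrete instance type; the
    // extents 1024/1024/512 are example values, not anything from the commit.
    template <class DeviceOpInstance>
    inline constexpr bool shape_ok_sketch =
        DeviceOpInstance::IsSupported(/*MRaw*/ 1024, /*NRaw*/ 1024, /*KRaw*/ 512);

    // e.g.:  static_assert(shape_ok_sketch<MyGemmMultipleDInstance>,
    //                      "tile/vector configuration cannot serve this M/N/K");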
@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        return str.str();
    }

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    struct Descriptor
    {
        static constexpr auto ds_tuple()
        {
            return transform_tuples(
                [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
                DsDesc{});
        }

        using AGridDesc_M_K =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
        using BGridDesc_N_K =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
        using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
        using EGridDesc_M_N =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
        using AGridDesc_AK0_M_AK1 =
            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
                DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
        using BGridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
                DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
            decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_tuple()))>;
        using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
            decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
        using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
            DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k;
        BGridDesc_N_K b_grid_desc_n_k;
        DsGridDesc_M_N ds_grid_desc_m_n;
        EGridDesc_M_N e_grid_desc_m_n;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        CDEElementwiseOperation cde_element_op;

        // for checking vector load/store
        index_t MRaw;
        index_t NRaw;
        index_t KRaw;

        bool has_main_k_block_loop = true;

        constexpr Descriptor(ADesc a,
                             BDesc b,
                             DsDesc ds,
                             EDesc e,
                             AElementwiseOperation a_element_op_,
                             BElementwiseOperation b_element_op_,
                             CDEElementwiseOperation cde_element_op_)
            : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
              b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
              ds_grid_desc_m_n{transform_tuples(
                  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
                  ds)},
              e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
              a_grid_desc_ak0_m_ak1{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
              b_grid_desc_bk0_n_bk1{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      transform_tuples(
                          [&](auto d) constexpr {
                              return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
                          },
                          ds))},
              e_grid_desc_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      e_grid_desc_m_n)},
              block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              cde_element_op{cde_element_op_},
              MRaw{e.GetLength(I0)},
              NRaw{e.GetLength(I1)},
              KRaw{a.GetLength(I1)}
        {
        }

        constexpr bool IsValid() const
        {
            return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                               b_grid_desc_n_k,
                                               ds_grid_desc_m_n,
                                               e_grid_desc_m_n,
                                               block_2_etile_map) and
                   IsSupported(MRaw, NRaw, KRaw);
        }

        constexpr index_t GetBlockSize() const { return BlockSize; }

        constexpr index_t GetGridSize() const
        {
            return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
        }
    };

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    static constexpr auto make_descriptor(ADesc a,
                                          BDesc b,
                                          DsDesc ds,
                                          EDesc e,
                                          AElementwiseOperation a_element_op = AElementwiseOperation{},
                                          BElementwiseOperation b_element_op = BElementwiseOperation{},
                                          CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
            a, b, ds, e, a_element_op, b_element_op, cde_element_op);
    }

    template <class Desc, class DsPointer>
    __device__ static void Run(const Desc& desc,
                               const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               DsPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid)
    {
        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        assert(desc.IsValid());

        if(desc.has_main_k_block_loop)
        {
            GridwiseGemm::template Run<true>(p_a_grid,
                                             p_b_grid,
                                             p_ds_grid,
                                             p_e_grid,
                                             p_shared_block,
                                             desc.a_element_op,
                                             desc.b_element_op,
                                             desc.cde_element_op,
                                             desc.a_grid_desc_ak0_m_ak1,
                                             desc.b_grid_desc_bk0_n_bk1,
                                             desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                             desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                             desc.block_2_etile_map);
        }
        else
        {
            GridwiseGemm::template Run<false>(p_a_grid,
                                              p_b_grid,
                                              p_ds_grid,
                                              p_e_grid,
                                              p_shared_block,
                                              desc.a_element_op,
                                              desc.b_element_op,
                                              desc.cde_element_op,
                                              desc.a_grid_desc_ak0_m_ak1,
                                              desc.b_grid_desc_bk0_n_bk1,
                                              desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                              desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                              desc.block_2_etile_map);
        }
    }
};

} // namespace device
...
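The Descriptor/make_descriptor/Run additions let a caller precompute every padded descriptor, the tile map, and the main-loop flag once, validate them, and reuse them across launches. A host-side sketch, under the assumption that a_desc, b_desc, ds_descs and e_desc are already-built tensor descriptors and DeviceOp names a concrete instantiation:

    #include <stdexcept>

    // Sketch only: DeviceOp, a_desc, b_desc, ds_descs, e_desc assumed to exist.
    auto desc = DeviceOp::make_descriptor(a_desc, b_desc, ds_descs, e_desc);

    if(!desc.IsValid()) // CheckValidity of padded descriptors + IsSupported(M, N, K)
    {
        throw std::runtime_error("shape not supported by this instance");
    }

    // A wrapping __global__ kernel then forwards the bundle to DeviceOp::Run,
    // launched with desc.GetGridSize() blocks of desc.GetBlockSize() threads.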
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
View file @ ccaea50e

...
@@ -61,8 +61,7 @@ __global__ void
     bool input_permute,
     bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************
...
@@ -169,7 +168,7 @@ __global__ void
     ignore = G1;
     ignore = input_permute;
     ignore = output_permute;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }

 // Computes C = A * B0 * B1
...
@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...
@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
...