Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6dbced07
Commit
6dbced07
authored
Sep 26, 2023
by
letaoqin
Browse files
change mha infer class and file name
parent
63ea1d70
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
230 additions
and
227 deletions
+230
-227
example/52_flash_atten_bias/CMakeLists.txt
example/52_flash_atten_bias/CMakeLists.txt
+3
-3
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_infer.cpp
...tten_bias/batched_gemm_multihead_attention_bias_infer.cpp
+68
-67
example/52_flash_atten_bias/batched_gemm_multihead_attention_infer.cpp
...ash_atten_bias/batched_gemm_multihead_attention_infer.cpp
+68
-67
example/52_flash_atten_bias/grouped_mutihead_attention_bias_infer.cpp
...lash_atten_bias/grouped_mutihead_attention_bias_infer.cpp
+68
-67
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
+7
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
+15
-15
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp
+1
-1
No files found.
example/52_flash_atten_bias/CMakeLists.txt
View file @
6dbced07
add_example_executable
(
example_batched_multihead_attention_
forward
batched_gemm_multihead_attention_
forward
.cpp
)
add_example_executable
(
example_batched_multihead_attention_
infer
batched_gemm_multihead_attention_
infer
.cpp
)
add_example_executable
(
example_batched_multihead_attention_bias_
forward
batched_gemm_multihead_attention_bias_
forward
.cpp
)
add_example_executable
(
example_batched_multihead_attention_bias_
infer
batched_gemm_multihead_attention_bias_
infer
.cpp
)
add_example_executable
(
example_grouped_multihead_attention_bias_
forward
grouped_mutihead_attention_bias_
forward
.cpp
)
add_example_executable
(
example_grouped_multihead_attention_bias_
infer
grouped_mutihead_attention_bias_
infer
.cpp
)
add_example_executable
(
example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp
)
add_example_executable
(
example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp
)
add_example_executable
(
example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp
)
add_example_executable
(
example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp
)
...
...
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_
forward
.cpp
→
example/52_flash_atten_bias/batched_gemm_multihead_attention_bias_
infer
.cpp
View file @
6dbced07
...
@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
...
@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_
fwd
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_
infer
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
...
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
...
example/52_flash_atten_bias/batched_gemm_multihead_attention_
forward
.cpp
→
example/52_flash_atten_bias/batched_gemm_multihead_attention_
infer
.cpp
View file @
6dbced07
...
@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
...
@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_
fwd
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_
infer
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
...
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
...
example/52_flash_atten_bias/grouped_mutihead_attention_bias_
forward
.cpp
→
example/52_flash_atten_bias/grouped_mutihead_attention_bias_
infer
.cpp
View file @
6dbced07
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_
fwd
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_
infer
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -66,7 +66,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
...
@@ -66,7 +66,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecB1
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecC
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionForward_Xdl
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_
fwd
_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_
infer
_xdl_cshuffle.hpp
View file @
6dbced07
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_
fwd
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_
infer
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -44,7 +44,7 @@ __global__ void
...
@@ -44,7 +44,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_batched_multiple_head_flash_attention_
forward
(
kernel_batched_multiple_head_flash_attention_
infer
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
D0DataType
*
p_d0_grid
,
const
D0DataType
*
p_d0_grid
,
...
@@ -205,7 +205,7 @@ template <index_t NumDimG,
...
@@ -205,7 +205,7 @@ template <index_t NumDimG,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
int
D0sTransferSrcScalarPerVector
=
4
,
int
D0sTransferSrcScalarPerVector
=
4
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttention
Forward_Xdl
struct
DeviceBatchedMultiheadAttention
Infer_Xdl_CShuffle
:
public
DeviceBatchedMultiheadAttentionInfer
<
NumDimG
,
:
public
DeviceBatchedMultiheadAttentionInfer
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -243,7 +243,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
...
@@ -243,7 +243,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
static constexpr index_t NumDimGemm1K = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttention
Forward_Xdl
;
using
DeviceOp
=
DeviceBatchedMultiheadAttention
Infer_Xdl_CShuffle
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
...
@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
};
};
using
GridwiseGemm
=
GridwiseMultiHeadFlashAttention
Forward
_Xdl_CShuffle
<
using
GridwiseGemm
=
GridwiseMultiHeadFlashAttention
Infer
_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
D0DataType
,
D0DataType
,
GemmAccDataType
,
GemmAccDataType
,
...
@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
...
@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_multiple_head_flash_attention_
forward
<
const
auto
kernel
=
kernel_batched_multiple_head_flash_attention_
infer
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
D0DataType
,
D0DataType
,
...
@@ -925,7 +925,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
...
@@ -925,7 +925,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceBatchedMultiheadAttention
Forward_Xdl
"
str
<<
"DeviceBatchedMultiheadAttention
Infer_Xdl_CShuffle
"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_
fwd
_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_
infer
_xdl_cshuffle.hpp
View file @
6dbced07
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_mha_infer.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_mha_infer.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_
fwd
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_
infer
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -35,7 +35,7 @@ __global__ void
...
@@ -35,7 +35,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_
gemm_softmax_gemm_xdl_cshuffle_v1
(
kernel_grouped_
multiple_head_flash_attention_infer
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
...
@@ -194,7 +194,7 @@ template <index_t NumDimG,
...
@@ -194,7 +194,7 @@ template <index_t NumDimG,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceGroupedMultiheadAttention
Forward_Xdl
struct
DeviceGroupedMultiheadAttention
Infer_Xdl_CShuffle
:
public
DeviceGroupedMultiheadAttentionInfer
<
NumDimG
,
:
public
DeviceGroupedMultiheadAttentionInfer
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -230,7 +230,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
...
@@ -230,7 +230,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
static constexpr index_t NumDimGemm1K = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
#endif
using
DeviceOp
=
DeviceGroupedMultiheadAttention
Forward_Xdl
;
using
DeviceOp
=
DeviceGroupedMultiheadAttention
Infer_Xdl_CShuffle
;
using
ProblemDesc
=
typename
DeviceGroupedMultiheadAttentionInfer
<
NumDimG
,
using
ProblemDesc
=
typename
DeviceGroupedMultiheadAttentionInfer
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -382,7 +382,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
...
@@ -382,7 +382,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseMultiHeadFlashAttention
Forward
_Xdl_CShuffle
<
using
GridwiseGemm
=
GridwiseMultiHeadFlashAttention
Infer
_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
Acc0BiasDataType
,
Acc0BiasDataType
,
GemmAccDataType
,
GemmAccDataType
,
...
@@ -698,7 +698,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
...
@@ -698,7 +698,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
const
auto
kernel
=
kernel_grouped_
gemm_softmax_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_grouped_
multiple_head_flash_attention_infer
<
GridwiseGemm
,
D0DataType
,
D0DataType
,
GroupKernelArg
,
GroupKernelArg
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -944,7 +944,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
...
@@ -944,7 +944,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceGroupedMultiheadAttention
Forward_Xdl
"
str
<<
"DeviceGroupedMultiheadAttention
Infer_Xdl_CShuffle
"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_
fwd
_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_
infer
_xdl_cshuffle.hpp
View file @
6dbced07
...
@@ -86,7 +86,7 @@ template <typename FloatAB,
...
@@ -86,7 +86,7 @@ template <typename FloatAB,
bool
PadN
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseMultiHeadFlashAttention
Forward
_Xdl_CShuffle
struct
GridwiseMultiHeadFlashAttention
Infer
_Xdl_CShuffle
{
{
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
static_assert
(
D0BlockTransferSrcScalarPerVector
==
1
||
D0BlockTransferSrcScalarPerVector
==
2
||
D0BlockTransferSrcScalarPerVector
==
2
||
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment