gaoqiong / composable_kernel

Commit c7fa22cb
authored Feb 08, 2023 by danyao12

rename attn fwd pass files

Parent: cc974f0f

Showing 12 changed files with 62 additions and 62 deletions (+62 −62)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+4 −4)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_bf16.cpp (+3 −3)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp (+3 −3)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp (+3 −3)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp (+3 −3)
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc (+0 −0)
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc (+0 −0)
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp (+1 −1)
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp (+25 −25)
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp (+18 −18)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+1 −1)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -3,12 +3,12 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp)
-add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16 grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward_fp16 grouped_multihead_attention_forward_fp16.cpp)
+add_example_executable(example_batched_multihead_attention_forward_fp16 batched_multihead_attention_forward_fp16.cpp)
+add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
+add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
 add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp → example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_bf16.cpp

@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;
-#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_batched_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
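The hunk headers above name the operation these renamed examples exercise: a fused Gemm + Softmax + Gemm, C_g_m_o = Softmax(A_g_m_k * B0_g... (the context line is truncated by GitLab; given the NumDimN and NumDimK parameters, the contraction is presumably over k in the first GEMM and over n in the second, against a B1 operand). For orientation, here is a minimal host-side C++ sketch of that computation on one batch slice. It is an illustration only, not CK's ReferenceBatchedGemm; every name in it is local to the sketch.

    // Reference Softmax(A * B0) * B1 for a single batch slice, float only.
    // Dimension names M/K/N/O mirror the example's NumDim* template parameters.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 2, K = 3, N = 4, O = 2;
        std::vector<float> A(M * K, 0.1f), B0(K * N, 0.2f), B1(N * O, 0.3f);
        std::vector<float> S(M * N), C(M * O, 0.0f);

        // Gemm0: S = A * B0
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.0f;
                for(int k = 0; k < K; ++k)
                    acc += A[m * K + k] * B0[k * N + n];
                S[m * N + n] = acc;
            }

        // Row-wise softmax over n, stabilized by subtracting the row max
        for(int m = 0; m < M; ++m)
        {
            float mx = S[m * N];
            for(int n = 1; n < N; ++n)
                mx = std::fmax(mx, S[m * N + n]);
            float sum = 0.0f;
            for(int n = 0; n < N; ++n)
            {
                S[m * N + n] = std::exp(S[m * N + n] - mx);
                sum += S[m * N + n];
            }
            for(int n = 0; n < N; ++n)
                S[m * N + n] /= sum;
        }

        // Gemm1: C = Softmax(S) * B1
        for(int m = 0; m < M; ++m)
            for(int o = 0; o < O; ++o)
                for(int n = 0; n < N; ++n)
                    C[m * O + o] += S[m * N + n] * B1[n * O + o];

        std::printf("C[0][0] = %f\n", C[0]);
    }

The fused device operators exist to avoid materializing the intermediate S matrix in global memory; this rename only retitles them as multihead-attention forward kernels without changing that structure.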
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_fp16.cpp

@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;
-#include "run_batched_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_batched_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp → example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_bf16.cpp

@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;
-#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_grouped_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp → example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_fp16.cpp

@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -68,7 +68,7 @@ static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
@@ -157,6 +157,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;
-#include "run_grouped_gemm_scale_softmax_gemm_permute_train.inc"
+#include "run_grouped_multihead_attention_forward.inc"
 int main(int argc, char* argv[]) { return run(argc, argv); }
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_train.inc → example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc

File moved
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc → example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc

File moved
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp

@@ -84,7 +84,7 @@ template <index_t NumDimG,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           MaskingSpecialization MaskingSpec>
-struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
+struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
 {
     static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
     static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
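The two device headers in this commit only rename the abstract interface structs that example and client code program against; the concrete kernels live in the impl/ headers further below. A compact, hypothetical sketch of that layering follows: the class shapes and the Run() member are stand-ins for illustration, since the real BaseOperator interface carries the full template parameter list and argument/invoker factories rather than a plain virtual call.

    // Simplified stand-ins for CK's operator layering, not real declarations.
    #include <cstdio>
    #include <memory>

    struct BaseOperator // stand-in for ck::tensor_operation::device::BaseOperator
    {
        virtual ~BaseOperator() = default;
    };

    // Abstract interface renamed by this commit; client code names this type.
    struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
    {
        virtual void Run() const = 0;
    };

    // Concrete implementation from the impl/ header, selected at compile time.
    struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle final
        : public DeviceBatchedMultiheadAttentionForward
    {
        void Run() const override { std::printf("xdl cshuffle kernel path\n"); }
    };

    int main()
    {
        std::unique_ptr<DeviceBatchedMultiheadAttentionForward> op =
            std::make_unique<DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle>();
        op->Run();
    }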
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp

@@ -88,7 +88,7 @@ template <index_t NumDimG,
           typename B1ElementwiseOperation,
           typename CElementwiseOperation,
           MaskingSpecialization MaskingSpec>
-struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
+struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
 {
     struct ProblemDesc
     {
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -47,7 +47,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid,
+    kernel_batched_multiheadattention_forward_xdl_cshuffle(const FloatAB* __restrict__ p_a_grid,
                                                      const FloatAB* __restrict__ p_b_grid,
                                                      const FloatAB* __restrict__ p_b1_grid,
@@ -205,8 +205,8 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
-    : public DeviceBatchedGemmSoftmaxGemmPermuteTrain<NumDimG,
+struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
+    : public DeviceBatchedMultiheadAttentionForward<NumDimG,
                                                       NumDimM,
                                                       NumDimN,
                                                       NumDimK,
@@ -244,7 +244,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -382,7 +382,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        float ave_time = 0;
        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
-           const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
+           const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
               GridwiseGemm,
               ADataType, // TODO: distiguish A/B datatype
               CDataType,
@@ -958,7 +958,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        auto str = std::stringstream();
        // clang-format off
        str << "DeviceBatched
-       str << "DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
+       str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
           << "<"
           << BlockSize << ", "
           << MPerBlock << ", "
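The launch_kernel lambda visible in the hunks above is a recurring CK idiom: the runtime conditions arrive as integral-constant-like objects (has_main_k_block_loop_, is_dropout_), so the body can lift them into compile-time template arguments of the kernel. Below is a minimal host-only sketch of that idiom under the assumption that dispatch works as in other CK device ops; kernel_stub and the flag wiring are hypothetical stand-ins, not CK symbols.

    // Sketch of selecting a kernel template instantiation from runtime flags.
    #include <cstdio>
    #include <type_traits>

    template <bool HasMainKBlockLoop, bool IsDropout>
    void kernel_stub() // stand-in for the real __global__ kernel template
    {
        std::printf("instantiated with HasMainKBlockLoop=%d IsDropout=%d\n",
                    HasMainKBlockLoop, IsDropout);
    }

    int main()
    {
        bool has_loop = true, dropout = false; // runtime conditions

        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
            // ::value lifts the runtime branch into template parameters
            constexpr bool kLoop = decltype(has_main_k_block_loop_)::value;
            constexpr bool kDrop = decltype(is_dropout_)::value;
            kernel_stub<kLoop, kDrop>();
        };

        if(has_loop)
            dropout ? launch_kernel(std::true_type{}, std::true_type{})
                    : launch_kernel(std::true_type{}, std::false_type{});
        else
            dropout ? launch_kernel(std::false_type{}, std::true_type{})
                    : launch_kernel(std::false_type{}, std::false_type{});
    }

Each of the four instantiations is a separate kernel, so the branching cost is paid once at launch time rather than inside every GPU thread.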
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp

@@ -14,7 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -37,7 +37,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
+    kernel_grouped_multiheadattention_forward_xdl_cshuffle(const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
                                                      const index_t group_count,
                                                      const AElementwiseOperation a_element_op,
@@ -197,8 +197,8 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
-    : public DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+    : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                       NumDimM,
                                                       NumDimN,
                                                       NumDimK,
@@ -236,8 +236,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
-    using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+    using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                           NumDimM,
                                                                           NumDimN,
                                                                           NumDimK,
@@ -392,7 +392,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
@@ -705,7 +705,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
           const auto kernel =
-              kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+              kernel_grouped_multiheadattention_forward_xdl_cshuffle<GridwiseGemm,
                                                                GemmAccDataType,
                                                                GroupKernelArg,
                                                                AElementwiseOperation,
@@ -969,7 +969,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        auto str = std::stringstream();
        // clang-format off
-       str << "DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
+       str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
           << "<"
           << BlockSize << ", "
           << MPerBlock << ", "
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp → include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -83,7 +83,7 @@ template <typename FloatAB,
           bool PadN,
           bool MaskOutUpperTriangle,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
+struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");