gaoqiong / composable_kernel · Commits

Commit 075abf92 authored Mar 08, 2023 by danyao12
Parent: c56f28b0

bwd template rename

Showing 5 changed files with 24 additions and 25 deletions
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp  (+6 -6)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp  (+6 -6)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+8 -9)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+3 -3)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp  (+1 -1)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp

@@ -36,8 +36,8 @@ Kernel outputs:
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"

@@ -97,7 +97,7 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpeciali
 // If 64 < Headdim/K/O <= 128, ues prototype2 2nd template.
 #if(RANGE_HDKO == 0)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,

@@ -165,7 +165,7 @@ using DeviceGemmInstance =
         MaskingSpec>; // MaskingSpecialization
 #elif(RANGE_HDKO == 1)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,

@@ -233,7 +233,7 @@ using DeviceGemmInstance =
         MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstance =
-//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
 //         NumDimG,
 //         NumDimM,
 //         NumDimN,

@@ -301,7 +301,7 @@ using DeviceGemmInstance =
 //         MaskingSpec>; // MaskingSpecialization
 #elif(RANGE_HDKO == 2)
 using DeviceGemmInstance =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
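Note for readers skimming the diff: the example above picks its backward device op at compile time through the RANGE_HDKO macro, and this commit only renames the classes that get selected (CShuffle_PT1 becomes CShuffle_V1, and the unsuffixed CShuffle becomes CShuffle_V2). The sketch below is a self-contained, hypothetical analogue of that dispatch, not CK code: KernelV1/KernelV2 and the default value of RANGE_HDKO are placeholders, and the real instantiations take the full template-argument lists elided in the hunks above.

// Hypothetical stand-in for the RANGE_HDKO dispatch used in the example above.
// KernelV1/KernelV2 are placeholders for the renamed CK device ops
// (..._Xdl_CShuffle_V1 / ..._Xdl_CShuffle_V2); they are not CK types.
#include <cstdio>

struct KernelV1 { static constexpr const char* name = "CShuffle_V1 (prototype 1)"; };
struct KernelV2 { static constexpr const char* name = "CShuffle_V2 (prototype 2)"; };

#ifndef RANGE_HDKO
#define RANGE_HDKO 0 // assumed default for this sketch; build with -DRANGE_HDKO=<0|1|2>
#endif

#if(RANGE_HDKO == 0) || (RANGE_HDKO == 1)
using DeviceGemmInstance = KernelV1; // these branches instantiate the _V1 template in the example
#elif(RANGE_HDKO == 2)
using DeviceGemmInstance = KernelV2; // this branch instantiates the _V2 template in the example
#endif

int main()
{
    std::printf("RANGE_HDKO == %d selects %s\n", RANGE_HDKO, DeviceGemmInstance::name);
    return 0;
}

How the example actually defines RANGE_HDKO (build flag or in-source default) is not shown in this diff; the head-dim ranges in the comments come from the "If 64 < Headdim/K/O <= 128" note in the source.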
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp

@@ -43,8 +43,8 @@ Kernel outputs:
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

@@ -176,7 +176,7 @@ using DeviceGemmInstanceFWD =
         MaskingSpec>; // MaskingSpecialization
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,

@@ -314,7 +314,7 @@ using DeviceGemmInstanceFWD =
         MaskingSpec>; // MaskingSpecialization
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         NumDimG,
         NumDimM,
         NumDimN,

@@ -382,7 +382,7 @@ using DeviceGemmInstanceBWD =
         MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstanceBWD =
-//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
 //         NumDimG,
 //         NumDimM,
 //         NumDimN,

@@ -520,7 +520,7 @@ using DeviceGemmInstanceFWD =
         MaskingSpec>; // MaskingSpecialization
 using DeviceGemmInstanceBWD =
-    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
         NumDimG,
         NumDimM,
         NumDimN,
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -16,7 +16,7 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

@@ -50,10 +50,9 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
 #endif
-    kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1(
+    kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
         const DataType* __restrict__ p_a_grid,
         const DataType* __restrict__ p_b_grid,
         ZDataType* __restrict__ p_z_grid,

@@ -233,7 +232,7 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
+struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,

@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
     };
     // GridwiseGemm
-    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+    using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
         DataType, // TODO: distinguish A/B datatype
         GemmDataType,
         GemmAccDataType,

@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1<
+            const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
                 GridwiseGemm,
                 DataType,
                 ZDataType,

@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
+        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
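Aside on the second hunk above: besides the kernel rename, it removes the commented-out launch-bounds line and the inline /*CK_MIN_BLOCK_PER_CU*/ note, leaving __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1). As a reminder of what that attribute does (generic HIP/CUDA semantics, nothing CK-specific), here is a minimal, hypothetical kernel; MY_MAX_THREADS_PER_BLOCK and toy_kernel are placeholders, not CK names.

// Minimal sketch of __launch_bounds__ usage (hypothetical kernel, not CK code).
// First argument: the maximum threads per block the kernel will be launched with.
// Second argument: a minimum-occupancy hint the compiler may use when budgeting
// registers (interpreted slightly differently by CUDA and HIP).
#include <hip/hip_runtime.h>

#define MY_MAX_THREADS_PER_BLOCK 256 // placeholder for CK_MAX_THREAD_PER_BLOCK

__global__ void __launch_bounds__(MY_MAX_THREADS_PER_BLOCK, 1)
toy_kernel(float* out)
{
    // Trivial body so the translation unit compiles with hipcc as-is.
    out[threadIdx.x] = static_cast<float>(threadIdx.x);
}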
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -231,7 +231,7 @@ template <index_t NumDimG,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           MaskingSpecialization MaskingSpec,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
+struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,

@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle;
+    using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle"
+        str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp → include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp

@@ -85,7 +85,7 @@ template <typename DataType,
           bool PadN,
           bool MaskOutUpperTriangle,
           PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
+struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 {
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");