Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8ef97116
Commit
8ef97116
authored
Mar 07, 2023
by
fsx950223
Browse files
update
parent
d9579dc8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
5 deletions
+10
-5
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
...oftmax_gemm/grouped_multihead_attention_backward_fp16.cpp
+7
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...grouped_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+3
-1
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
View file @
8ef97116
...
@@ -25,7 +25,7 @@ Kernel outputs:
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define
VERSION 1
#define
DIM 64
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -62,6 +62,7 @@ using QKVElementOp = PassThrough;
...
@@ -62,6 +62,7 @@ using QKVElementOp = PassThrough;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
F16
;
using
DataType
=
F16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
...
@@ -89,7 +90,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
...
@@ -89,7 +90,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
#if
VERSION == 1
#if
DIM >=128
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
...
@@ -156,7 +157,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
...
@@ -156,7 +157,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#elif
VERSION == 2
#elif
DIM >64
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
...
@@ -223,7 +224,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
...
@@ -223,7 +224,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
>
;
// MaskingSpecialization
#elif
VERSION == 3
#elif
DIM >32
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
...
@@ -231,6 +232,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
...
@@ -231,6 +232,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
DataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
...
@@ -297,6 +299,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
...
@@ -297,6 +299,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
DataType
,
DataType
,
GemmDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
8ef97116
...
@@ -151,6 +151,7 @@ template <index_t NumDimG,
...
@@ -151,6 +151,7 @@ template <index_t NumDimG,
index_t
NumDimK
,
index_t
NumDimK
,
index_t
NumDimO
,
// NumDimGemm1N
index_t
NumDimO
,
// NumDimGemm1N
typename
DataType
,
typename
DataType
,
typename
GemmDataType
,
typename
ZDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
typename
Acc0BiasDataType
,
...
@@ -535,9 +536,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
...
@@ -535,9 +536,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
<
DataType
,
// TODO: distinguish A/B datatype
DataType
,
// TODO: distinguish A/B datatype
LSE
DataType
,
Gemm
DataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
LSEDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment