Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dded3450
Commit
dded3450
authored
May 29, 2023
by
danyao12
Browse files
prototype2 qloop direction w/ layout change
parent
59088eca
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3873 additions
and
7 deletions
+3873
-7
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+76
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
+1306
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
+2491
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
dded3450
...
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -37,7 +37,7 @@ Kernel outputs:
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
true
;
static
constexpr
bool
Deterministic
=
false
;
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
...
...
@@ -298,6 +298,75 @@ using DeviceGemmInstance =
// MaskingSpec,
// Deterministic>;
#elif(DIM <= 128)
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// InputDataType,
// OutputDataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 128, // Gemm1NPerBlock
// 32, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 4, // MXdlPerWave
// 1, // NXdlPerWave
// 4, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 4,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 4, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// CShuffleBlockTransferScalarPerVector_NPerBlock, //
// CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec, // MaskingSpecialization
// Deterministic>;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
NumDimG
,
...
...
@@ -326,7 +395,7 @@ using DeviceGemmInstance =
TensorSpecY
,
1
,
256
,
128
,
// MPerBlock
64
,
// MPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
128
,
// Gemm1NPerBlock
...
...
@@ -336,10 +405,10 @@ using DeviceGemmInstance =
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// MXdlPerWave
1
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
1
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp
0 → 100644
View file @
dded3450
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp
0 → 100644
View file @
dded3450
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment