Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3480e42b
Commit
3480e42b
authored
May 24, 2023
by
guangzlu
Browse files
fixed bugs after merge
parent
e439b369
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
7 deletions
+18
-7
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+9
-5
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v3.hpp
+9
-2
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
3480e42b
...
...
@@ -105,7 +105,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
tru
e
;
static
constexpr
bool
Deterministic
=
fals
e
;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
...
...
@@ -190,7 +190,8 @@ using DeviceGemmInstanceBWD =
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -249,7 +250,8 @@ using DeviceGemmInstanceBWD =
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
,
Deterministic
>
;
// MaskingSpecialization
#elif(DIM <= 64)
using
DeviceGemmInstanceFWD
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
<
...
...
@@ -329,7 +331,8 @@ using DeviceGemmInstanceBWD =
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
InputDataType
,
OutputDataType
,
GemmDataType
,
ZDataType
,
LSEDataType
,
...
...
@@ -388,7 +391,8 @@ using DeviceGemmInstanceBWD =
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec
>
;
// MaskingSpecialization
MaskingSpec
,
Deterministic
>
;
// MaskingSpecialization
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v3.hpp
View file @
3480e42b
...
...
@@ -146,6 +146,9 @@ __global__ void
c0_matrix_mask
,
p_drop
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
}
}
...
...
@@ -178,6 +181,9 @@ __global__ void
c0_matrix_mask
,
p_drop
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
}
...
...
@@ -1007,8 +1013,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
c0_matrix_mask_
,
arg
.
p_drop_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment