Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ff6e303d
Commit
ff6e303d
authored
May 17, 2023
by
danyao12
Browse files
remove useless templates
parent
1970c4d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
48 deletions
+6
-48
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
..._softmax_gemm/batched_multihead_attention_backward_v4.cpp
+2
-16
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+4
-24
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
+0
-8
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v4.cpp
View file @
ff6e303d
...
...
@@ -62,8 +62,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
DataType
=
B
F16
;
using
GemmDataType
=
B
F16
;
using
DataType
=
F16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
...
...
@@ -151,13 +151,6 @@ using DeviceGemmInstance =
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
64
,
1
,
4
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -219,13 +212,6 @@ using DeviceGemmInstance =
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
ff6e303d
...
...
@@ -216,13 +216,6 @@ template <index_t NumDimG,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_BK1
,
bool
B1BlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -644,14 +637,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
false
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -1019,14 +1004,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
auto
Gemm1NzRaw
=
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
3
];
// Check scalar per vector requirement
const
auto
a_extent_lowest
=
ABlockTransferSrcVectorDim
==
2
?
KzRaw
:
MzRaw
;
const
auto
b_extent_lowest
=
BBlockTransferSrcVectorDim
==
2
?
KzRaw
:
NzRaw
;
const
auto
b1_extent_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
NzRaw
:
Gemm1NzRaw
;
const
auto
c_extent_lowest
=
Gemm1NzRaw
;
const
auto
a_extent_lowest
=
ABlockTransferSrcVectorDim
==
2
?
KzRaw
:
MzRaw
;
const
auto
b_extent_lowest
=
BBlockTransferSrcVectorDim
==
2
?
KzRaw
:
NzRaw
;
const
auto
c_extent_lowest
=
Gemm1NzRaw
;
if
(
!
(
a_extent_lowest
%
ABlockTransferSrcScalarPerVector
==
0
&&
b_extent_lowest
%
BBlockTransferSrcScalarPerVector
==
0
&&
b1_extent_lowest
%
B1BlockTransferSrcScalarPerVector
==
0
&&
c_extent_lowest
%
CShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
...
...
@@ -1037,13 +1020,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ABlockTransferSrcVectorDim
==
2
?
arg
.
a_mz_kz_strides_
[
1
]
:
arg
.
a_mz_kz_strides_
[
0
];
const
auto
b_stride_lowest
=
BBlockTransferSrcVectorDim
==
2
?
arg
.
b_nz_kz_strides_
[
1
]
:
arg
.
b_nz_kz_strides_
[
0
];
const
auto
b1_stride_lowest
=
B1BlockTransferSrcVectorDim
==
2
?
arg
.
b1_nz_kz_strides_
[
1
]
:
arg
.
b1_nz_kz_strides_
[
0
];
const
auto
c_stride_lowest
=
arg
.
c_mz_gemm1nz_strides_
[
1
];
// cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if
(
!
(
a_stride_lowest
==
1
||
b_stride_lowest
==
1
||
b1_stride_lowest
==
1
||
c_stride_lowest
==
1
))
if
(
!
(
a_stride_lowest
==
1
||
b_stride_lowest
==
1
||
c_stride_lowest
==
1
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp
View file @
ff6e303d
...
...
@@ -70,14 +70,6 @@ template <typename DataType,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
// ignored
index_t
BBlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_BK1
,
bool
B1ThreadTransferSrcResetCoordinateAfterRun
,
index_t
B1BlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment