Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
807ac476
"src/include/blockwise_4d_tensor_op.hpp" did not exist on "b2439ec9dd8acc7a6788c3225fda80eb7f416ce6"
Commit
807ac476
authored
Feb 27, 2023
by
ltqin
Browse files
add Gemm2NXdlPerWave template parameter
parent
b5fbb74b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
2 deletions
+7
-2
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
...ice_batched_multihead_attention_backward_xdl_cshuffle.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+3
-2
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
807ac476
...
...
@@ -130,6 +130,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
@@ -198,6 +199,7 @@ using DeviceGemmInstance =
1
,
// MXdlPerWave
4
,
// NXdlPerWave
2
,
// Gemm1NXdlPerWave
2
,
// Gemm2NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
View file @
807ac476
...
...
@@ -203,6 +203,7 @@ template <index_t NumDimG,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
Gemm2NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -627,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
Gemm2NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
807ac476
...
...
@@ -52,6 +52,7 @@ template <typename DataType,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
Gemm1NXdlPerWave
,
index_t
Gemm2NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -662,9 +663,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static
constexpr
index_t
BSrcVectorDim
=
1
;
// Free1_O dimension
static
constexpr
index_t
BSrcScalarPerVector
=
4
;
static
constexpr
index_t
GemmNWave
=
2
;
static
constexpr
index_t
GemmNWave
=
Free0_N
/
Gemm2NXdlPerWave
/
MPerXdl
;
static
constexpr
index_t
GemmOWave
=
BlockSize
/
get_warp_size
()
/
GemmNWave
;
static
constexpr
index_t
GemmNRepeat
=
Free0_N
/
GemmNWave
/
MPerXdl
;
static
constexpr
index_t
GemmNRepeat
=
Gemm2NXdlPerWave
;
static
constexpr
index_t
GemmORepeat
=
Free1_O
/
GemmOWave
/
NPerXdl
;
static
constexpr
index_t
GemmMPack
=
math
::
max
(
math
::
lcm
(
A_M1
,
B_M1
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment