Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1de7de06
Commit
1de7de06
authored
Jan 21, 2023
by
danyao12
Browse files
attn bwd kernel prototype1
parent
7409bc5d
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
4027 additions
and
0 deletions
+4027
-0
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
+681
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+16
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+1163
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+2167
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
0 → 100644
View file @
1de7de06
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
1de7de06
...
@@ -859,6 +859,21 @@ struct BlockwiseGemmXdlops_v2
...
@@ -859,6 +859,21 @@ struct BlockwiseGemmXdlops_v2
"wrong!"
);
"wrong!"
);
}
}
__host__
__device__
BlockwiseGemmXdlops_v2
(
index_t
switch_flag
,
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
(),
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
())
:
switch_flag_
(
switch_flag
),
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
AMmaTileDesc
::
IsKnownAtCompileTime
()
&&
BMmaTileDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
}
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
{
{
...
@@ -1126,6 +1141,7 @@ struct BlockwiseGemmXdlops_v2
...
@@ -1126,6 +1141,7 @@ struct BlockwiseGemmXdlops_v2
B_K1
,
B_K1
,
B_K1
>
;
B_K1
>
;
index_t
switch_flag_
;
AThreadCopy
a_thread_copy_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
0 → 100644
View file @
1de7de06
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
0 → 100644
View file @
1de7de06
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment