Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0cf81fde
Commit
0cf81fde
authored
Dec 27, 2022
by
Anthony Chang
Browse files
rename
parent
394f9207
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
6 deletions
+8
-6
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
...ice_batched_multihead_attention_backward_xdl_cshuffle.hpp
+6
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+1
-1
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
0cf81fde
...
...
@@ -33,7 +33,7 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_
gemm_softmax_gemm_permute
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_
multihead_attention_backward
_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_
gemm_softmax_gemm_permute
_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_batched_
multihead_attention_backward
_xdl_cshuffle.hpp
View file @
0cf81fde
...
...
@@ -9,11 +9,13 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_
gemm_softmax_gemm
_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_
multihead_attention_backward
_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -206,7 +208,7 @@ template <index_t NumDimG,
MaskingSpecialization
MaskingSpec
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
:
public
BaseOperator
// TODO inherit atten bwd op
:
public
BaseOperator
// TODO inherit atten bwd op
once API stablizes
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
...
...
@@ -552,7 +554,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
};
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatched
GemmSoftmaxGemm
_Xdl_CShuffle
<
using
GridwiseGemm
=
GridwiseBatched
MultiheadAttentionBackward
_Xdl_CShuffle
<
DataType
,
// TODO: distinguish A/B datatype
LSEDataType
,
GemmAccDataType
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_
gemm_softmax_gemm
_xdl_cshuffle_v1.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_batched_
multihead_attention_backward
_xdl_cshuffle_v1.hpp
View file @
0cf81fde
...
...
@@ -80,7 +80,7 @@ template <typename DataType,
bool
PadN
,
bool
MaskOutUpperTriangle
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatched
GemmSoftmaxGemm
_Xdl_CShuffle
struct
GridwiseBatched
MultiheadAttentionBackward
_Xdl_CShuffle
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment