gaoqiong / composable_kernel · Commits

Commit 94177eb6
Authored Jun 14, 2023 by danyao12

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-dropout-v2

Parents: 44f4498a, 71e2a917
Showing 2 changed files with 10 additions and 10 deletions:

example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp (+3, -3)
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp (+7, -7)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp

@@ -98,9 +98,9 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecia
 static constexpr bool Deterministic = false;
 
 // DIM should be a multiple of 8.
-// If DIM <= 32 , ues prototype1 1st template.
-// If 32 < DIM <= 64 , ues prototype1 2nd template.
-// If 64 < DIM <= 128, ues prototype2 2nd template.
+// If DIM <= 32 , ues prototype1.
+// If 32 < DIM <= 64 , ues prototype1.
+// If 64 < DIM <= 128, ues prototype2.
 #if(DIM <= 32)
 // clang-format off
 using DeviceGemmInstance =
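Note on the comment change above: the example picks its DeviceGemmInstance at compile time with a preprocessor branch on DIM (the head dimension); only the first branch, #if(DIM <= 32), is visible in this hunk. A minimal sketch of the dispatch shape those comments describe, where Prototype1 and Prototype2 are hypothetical stand-ins for the real ck::tensor_operation::device instantiations defined in the example:

// Sketch only: Prototype1/Prototype2 stand in for the device-op
// instantiations the example actually defines per DIM range.
#define DIM 64 // head dimension; the example requires a multiple of 8

struct Prototype1; // configuration used for DIM <= 64
struct Prototype2; // configuration used for 64 < DIM <= 128

#if (DIM <= 32)
using DeviceGemmInstance = Prototype1;
#elif (DIM <= 64)
using DeviceGemmInstance = Prototype1;
#elif (DIM <= 128)
using DeviceGemmInstance = Prototype2;
#else
#error "this example covers DIM <= 128 only"
#endif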
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp

@@ -62,9 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
 
-using InputDataType = BF16;
-using OutputDataType = F32;
-using GemmDataType = BF16;
+using InputDataType = F16;
+using OutputDataType = F16;
+using GemmDataType = F16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
 using LSEDataType = F32;
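This hunk moves the grouped example from BF16 inputs with F32 output to F16 end to end, while AccDataType, ShuffleDataType, and LSEDataType stay F32. For reference, the F16/BF16/F32 shorthands are assumed here to be the usual CK example aliases (their definitions are not shown in this diff):

#include "ck/utility/data_type.hpp"

// Assumed alias definitions, matching the pattern used across CK examples.
using F16 = ck::half_t;   // 16-bit IEEE half precision
using BF16 = ck::bhalf_t; // 16-bit bfloat16
using F32 = float;        // 32-bit float, kept for accumulation and LSE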
@@ -79,7 +79,7 @@ static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
 // When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
 // When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
-static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
 
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
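The two comments in this hunk encode a simple width rule: 4 F32 elements and 8 F16/BF16 elements both span 16 bytes, so the scalar-per-vector count follows the element size of OutputDataType, which this commit just changed from F32 to F16. An illustrative sketch of that rule written once instead of toggled by hand (the example itself uses the literal constants shown above):

#include <cstddef>

// Illustrative only: keep each vectorized CShuffle store 16 bytes wide.
//   sizeof(float) == 4        -> 16 / 4 = 4 scalars per vector (F32 output)
//   sizeof(half/bhalf_t) == 2 -> 16 / 2 = 8 scalars per vector (F16/BF16 output)
template <typename OutputDataType>
constexpr int kCShuffleScalarPerVector = 16 / static_cast<int>(sizeof(OutputDataType));

static_assert(kCShuffleScalarPerVector<float> == 4, "F32 output -> 4 per vector");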
@@ -97,9 +97,9 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecia
 static constexpr bool Deterministic = false;
 
 // DIM should be a multiple of 8.
-// If DIM <= 32 , ues prototype1 1st template.
-// If 32 < DIM <= 64 , ues prototype1 2nd template.
-// If 64 < DIM <= 128, ues prototype2 2nd template.
+// If DIM <= 32 , ues prototype1.
+// If 32 < DIM <= 64 , ues prototype1.
+// If 64 < DIM <= 128, ues prototype2.
 #if(DIM <= 32)
 // clang-format off
 using DeviceGemmInstance =