Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0286b9bf
Commit
0286b9bf
authored
Jun 25, 2023
by
danyao12
Browse files
adjust parameters to verify
parent
f1a49daf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
15 deletions
+15
-15
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
...ale_softmax_gemm/batched_multihead_attention_backward.cpp
+6
-6
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+3
-3
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp
..._scale_softmax_gemm/grouped_multihead_attention_train.cpp
+6
-6
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward.cpp
View file @
0286b9bf
...
@@ -25,7 +25,7 @@ Kernel outputs:
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -63,9 +63,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
...
@@ -63,9 +63,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
InputDataType
=
BF16
;
using
InputDataType
=
F16
;
using
OutputDataType
=
F32
;
using
OutputDataType
=
F16
;
using
GemmDataType
=
BF16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
...
@@ -80,7 +80,7 @@ static constexpr ck::index_t NumDimK = 1;
...
@@ -80,7 +80,7 @@ static constexpr ck::index_t NumDimK = 1;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
4
;
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
...
@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
...
@@ -95,7 +95,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
true
;
static
constexpr
bool
Deterministic
=
false
;
// DIM should be a multiple of 8.
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
// If DIM <= 32, use prototype1 1st template.
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
0286b9bf
...
@@ -32,7 +32,7 @@ Kernel outputs:
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -89,7 +89,7 @@ static constexpr ck::index_t NumDimK = 1;
...
@@ -89,7 +89,7 @@ static constexpr ck::index_t NumDimK = 1;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
4
;
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
...
@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
...
@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
true
;
static
constexpr
bool
Deterministic
=
false
;
// DIM should be a multiple of 8.
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
// If DIM <= 32, use prototype1 1st template.
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train.cpp
View file @
0286b9bf
...
@@ -31,7 +31,7 @@ Kernel outputs:
...
@@ -31,7 +31,7 @@ Kernel outputs:
*/
*/
#define USING_MASK 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -71,9 +71,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
...
@@ -71,9 +71,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using
QKVElementOp
=
PassThrough
;
using
QKVElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
YElementOp
=
PassThrough
;
using
InputDataType
=
BF16
;
using
InputDataType
=
F16
;
using
OutputDataType
=
F32
;
using
OutputDataType
=
F16
;
using
GemmDataType
=
BF16
;
using
GemmDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
ShuffleDataType
=
F32
;
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
...
@@ -88,7 +88,7 @@ static constexpr ck::index_t NumDimK = 1;
...
@@ -88,7 +88,7 @@ static constexpr ck::index_t NumDimK = 1;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
static
constexpr
ck
::
index_t
NumDimO
=
1
;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
4
;
static
constexpr
ck
::
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
=
8
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
#if USING_MASK
#if USING_MASK
...
@@ -103,7 +103,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
...
@@ -103,7 +103,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecK
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecV
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
auto
TensorSpecY
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
static
constexpr
bool
Deterministic
=
true
;
static
constexpr
bool
Deterministic
=
false
;
// DIM should be a multiple of 8.
// DIM should be a multiple of 8.
// If DIM <= 32, use prototype1 1st template.
// If DIM <= 32, use prototype1 1st template.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment