Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ac407086
Commit
ac407086
authored
Feb 27, 2023
by
ltqin
Browse files
change k=64 config
parent
a465a936
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
6 deletions
+6
-6
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
...oftmax_gemm/batched_multihead_attention_backward_fp16.cpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
...ice_batched_multihead_attention_backward_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+1
-1
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
View file @
ac407086
...
@@ -25,7 +25,7 @@ Kernel outputs:
...
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 1
#define USING_MASK 1
#define USING_K128
1
#define USING_K128
0
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -213,7 +213,7 @@ using DeviceGemmInstance =
...
@@ -213,7 +213,7 @@ using DeviceGemmInstance =
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
2
,
2
,
2
,
false
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
...
@@ -340,8 +340,8 @@ int run(int argc, char* argv[])
...
@@ -340,8 +340,8 @@ int run(int argc, char* argv[])
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
O
=
64
;
#endif
#endif
ck
::
index_t
G0
=
54
;
ck
::
index_t
G0
=
3
;
ck
::
index_t
G1
=
16
;
ck
::
index_t
G1
=
2
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
View file @
ac407086
...
@@ -49,7 +49,7 @@ template <typename GridwiseGemm,
...
@@ -49,7 +49,7 @@ template <typename GridwiseGemm,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*
CK_MIN_BLOCK_PER_CU
*/
1
)
#endif
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
(
const
DataType
*
__restrict__
p_a_grid
,
const
DataType
*
__restrict__
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
ac407086
...
@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -90,7 +90,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
{
using
type
=
T
;
using
type
=
T
;
};
};
#if defined(__gfx90a__)
#if defined(__gfx90a_
masking_
_)
template
<
>
template
<
>
struct
TypeMap
<
ck
::
half_t
>
struct
TypeMap
<
ck
::
half_t
>
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment