Commit 7af1b43a ("one block dv pass")
Project: gaoqiong / composable_kernel
Author:  ltqin
Date:    May 19, 2023
Parent:  b281dac5
Showing 4 changed files with 115 additions and 115 deletions (+115, -115):

  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+1, -1)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_V2R2.cpp (+11, -10)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp (+102, -103)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+1, -1)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -12,7 +12,7 @@ add_example_executable(example_batched_multihead_attention_backward batched_mult
 add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
 add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
-add_example_executable(example_batched_multihead_attention_backward_V1R2 batched_multihead_attention_backward_V2R2.cpp)
+add_example_executable(example_batched_multihead_attention_backward_V2R2 batched_multihead_attention_backward_V2R2.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_V2R2.cpp

@@ -350,9 +350,9 @@ using DeviceGemmInstance =
         2,              // B1K1
         32,             // MPerXDL
         32,             // NPerXDL
-        4,              // MXdlPerWave
-        1,              // NXdlPerWave
-        1,              // Gemm1NXdlPerWave
+        1,              // MXdlPerWave
+        4,              // NXdlPerWave
+        4,              // Gemm1NXdlPerWave
         2,              // Gemm2NXdlPerWave
         S<4, 64, 1>,    // ABlockTransfer
         S<1, 0, 2>,
...
@@ -375,8 +375,8 @@ using DeviceGemmInstance =
         4,
         2,
         false,
-        4,              // CShuffleMXdlPerWavePerShuffle
-        1,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleMXdlPerWavePerShuffle
+        4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec,    // MaskingSpecialization
...
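For orientation on the values being swapped in the two hunks above: the *XdlPerWave parameters say how many XDL (MFMA) tiles a single wave accumulates along M and N, and the CShuffle*XdlPerWavePerShuffle parameters say how many of those tiles are routed through LDS per shuffle step, which is why the two hunks change in lockstep. Below is a minimal sketch of the tile arithmetic, assuming the usual CK convention MPerBlock = MXdlPerWave * MWaves * MPerXDL; the MWaves/NWaves numbers are illustrative placeholders, not values taken from this file.

#include <cstdint>
#include <iostream>

int main()
{
    // Values taken from the new (right-hand) side of the hunks above.
    constexpr std::int32_t MPerXDL     = 32;
    constexpr std::int32_t NPerXDL     = 32;
    constexpr std::int32_t MXdlPerWave = 1;
    constexpr std::int32_t NXdlPerWave = 4;

    // Hypothetical wave layout, for illustration only; in CK it follows from
    // BlockSize and the MPerBlock/NPerBlock template arguments.
    constexpr std::int32_t MWaves = 4;
    constexpr std::int32_t NWaves = 1;

    // Per-block tile under the assumed convention
    // MPerBlock = MXdlPerWave * MWaves * MPerXDL (and likewise for N).
    constexpr std::int32_t MPerBlock = MXdlPerWave * MWaves * MPerXDL; // 1 * 4 * 32 = 128
    constexpr std::int32_t NPerBlock = NXdlPerWave * NWaves * NPerXDL; // 4 * 1 * 32 = 128

    std::cout << "per-block tile: " << MPerBlock << " x " << NPerBlock << "\n";
    return 0;
}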
@@ -501,17 +501,17 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;
-    ck::index_t N  = 512;
+    ck::index_t M  = 128;
+    ck::index_t N  = 128;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
-    ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G0 = 1;
+    ck::index_t G1 = 1;
     bool input_permute  = false;
     bool output_permute = false;
-    float p_drop = 0.2;
+    float p_drop = 0;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
...
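The comment lines at the top of this hunk spell out what the example computes for each batch slice g, and with the new sizes (M = N = 128, G0 = G1 = 1, p_drop = 0) there is exactly one such slice and no dropout. Below is a single-slice, row-major host sketch of that formula; the function name and layouts are illustrative assumptions, not the example's own verification path.

#include <algorithm>
#include <cmath>
#include <vector>

// Naive reference for y_m_o = Softmax(alpha * Q_m_k * K_k_n) * V_n_o,
// mirroring the comment in the hunk above for one g slice. Dropout is
// omitted, matching p_drop = 0 in the new code.
std::vector<float> attention_ref(const std::vector<float>& Q, // M x K, row-major
                                 const std::vector<float>& K, // K x N, row-major
                                 const std::vector<float>& V, // N x O, row-major
                                 int M, int Kdim, int N, int O, float alpha)
{
    std::vector<float> Y(static_cast<std::size_t>(M) * O, 0.f);
    std::vector<float> S(N);
    for(int m = 0; m < M; ++m)
    {
        // S[n] = alpha * sum_k Q[m][k] * K[k][n]
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < Kdim; ++k)
                acc += Q[m * Kdim + k] * K[k * N + n];
            S[n] = alpha * acc;
        }
        // Numerically stable softmax over n.
        const float s_max = *std::max_element(S.begin(), S.end());
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            S[n] = std::exp(S[n] - s_max);
            sum += S[n];
        }
        // Y[m][o] = sum_n P[m][n] * V[n][o]
        for(int n = 0; n < N; ++n)
        {
            const float p = S[n] / sum;
            for(int o = 0; o < O; ++o)
                Y[m * O + o] += p * V[n * O + o];
        }
    }
    return Y;
}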
@@ -1040,6 +1040,7 @@ int run(int argc, char* argv[])
                                        "error",
                                        1e-2,
                                        1e-2);
+            //std::cout << vgrad_gs_os_ns_device_result << std::endl;
         }
     }
     return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
...
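For context on the two 1e-2 arguments kept unchanged above: judging by this call site they are the tolerances passed to the example's error check after the message string, i.e. a combined relative/absolute comparison of the device result against the host reference. A generic sketch of such a check follows; it is illustrative only and is not CK's check_err implementation.

#include <cmath>
#include <cstddef>
#include <vector>

// Mixed relative/absolute tolerance comparison, in the spirit of the
// 1e-2 / 1e-2 arguments above. Assumed helper, not a CK API.
bool all_close(const std::vector<float>& out,
               const std::vector<float>& ref,
               double rtol = 1e-2,
               double atol = 1e-2)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::abs(static_cast<double>(out[i]) - static_cast<double>(ref[i]));
        // Pass if the error is within atol + rtol * |reference|.
        if(err > atol + rtol * std::abs(static_cast<double>(ref[i])))
            return false;
    }
    return true;
}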
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp

(Diff collapsed in this view; +102, -103.)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -917,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         {
             block_sync_lds();
         }
         do
         {
             auto n_block_data_idx_on_grid =
...