Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d37c1d0b
Commit
d37c1d0b
authored
May 23, 2023
by
guangzlu
Browse files
dim=32 pass now
parent
5d90769e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
53 deletions
+88
-53
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
..._scale_softmax_gemm/batched_multihead_attention_train.cpp
+5
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
+48
-27
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+35
-23
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train.cpp
View file @
d37c1d0b
...
@@ -32,7 +32,7 @@ Kernel outputs:
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
32
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -730,7 +730,7 @@ int run(int argc, char* argv[])
...
@@ -730,7 +730,7 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
0
;
float
p_drop
=
0.
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
const
unsigned
long
long
offset
=
0
;
...
@@ -1040,7 +1040,8 @@ int run(int argc, char* argv[])
...
@@ -1040,7 +1040,8 @@ int run(int argc, char* argv[])
YElementOp
{},
YElementOp
{},
p_drop
,
p_drop
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
(
seed
,
offset
));
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
qgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
float
ave_time_bwd
=
invoker_bwd
.
Run
(
argument_bwd
,
StreamConfig
{
nullptr
,
true
});
float
ave_time_bwd
=
invoker_bwd
.
Run
(
argument_bwd
,
StreamConfig
{
nullptr
,
true
});
...
@@ -1149,6 +1150,7 @@ int run(int argc, char* argv[])
...
@@ -1149,6 +1150,7 @@ int run(int argc, char* argv[])
std
::
ofstream
fwd_file
(
"./z_fwd_matrix_txt"
);
std
::
ofstream
fwd_file
(
"./z_fwd_matrix_txt"
);
fwd_file
<<
z_fwd_gs_ms_ns
<<
std
::
endl
;
fwd_file
<<
z_fwd_gs_ms_ns
<<
std
::
endl
;
qgrad_device_buf
.
SetZero
();
kgrad_device_buf
.
SetZero
();
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt3.hpp
View file @
d37c1d0b
...
@@ -1525,7 +1525,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1525,7 +1525,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -1561,7 +1561,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1561,7 +1561,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence
<
I1
,
// MBlockId
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockId
I1
,
// NBlockId
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -1984,33 +1984,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1984,33 +1984,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
global_elem_id
=
auto
global_elem_id
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
index_t
id_step
=
Acc0TileIterator
::
GetNumOfAccess
()
/
n0
.
value
;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
true
,
true
>(
decltype
(
n0
),
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
global_elem_id
+
id_step
*
i
.
value
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
// P_dropped
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
// static_for<0, n0, 1>{}([&](auto i) {
});
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
// decltype(z_tenor_buffer),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
// true,
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
));
// decltype(n0),
// decltype(i)>(s_slash_p_thread_buf,
// ph,
// global_elem_id + id_step
// * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
}
else
else
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d37c1d0b
...
@@ -860,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -860,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -891,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -891,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence
<
I1
,
// MBlockId
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -1067,18 +1067,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1067,18 +1067,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// printf("at 1 global_elem_id is %d \n", global_elem_id);
// printf("at 1 global_elem_id is %d \n", global_elem_id);
// }
// }
index_t
id_step
=
Acc0TileIterator
::
GetNumOfAccess
()
/
n0
.
value
;
//
index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
// save z to global
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
false
,
false
>(
decltype
(
n0
),
acc_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
decltype
(
i
)>(
acc_thread_buf
,
ph
,
global_elem_id
+
i
.
value
*
id_step
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
@@ -1086,13 +1083,28 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1086,13 +1083,28 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_tenor_buffer
,
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
// static_for<0, n0, 1>{}([&](auto i) {
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
// blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
// decltype(z_tenor_buffer),
});
// false,
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
// decltype(n0),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
// decltype(i)>(
make_multi_index
(
0
,
0
,
0
,
-
(
n0
.
value
),
0
,
0
,
0
,
0
,
0
,
0
));
// acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment