Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cb914a54
Commit
cb914a54
authored
Jan 25, 2023
by
ltqin
Browse files
move z matrix pos
parent
1306ae6b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
9 deletions
+23
-9
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
...emm/batched_multihead_attention_backward_fp16_dropout.cpp
+10
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+13
-4
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
View file @
cb914a54
...
@@ -248,12 +248,12 @@ int run(int argc, char* argv[])
...
@@ -248,12 +248,12 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
5
12
;
ck
::
index_t
M
=
12
8
;
ck
::
index_t
N
=
512
;
ck
::
index_t
N
=
256
;
ck
::
index_t
K
=
128
;
ck
::
index_t
K
=
128
;
ck
::
index_t
O
=
128
;
ck
::
index_t
O
=
128
;
ck
::
index_t
G0
=
3
;
ck
::
index_t
G0
=
1
;
ck
::
index_t
G1
=
2
;
ck
::
index_t
G1
=
1
;
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
float
alpha
=
1.
f
/
std
::
sqrt
(
K
);
...
@@ -363,7 +363,7 @@ int run(int argc, char* argv[])
...
@@ -363,7 +363,7 @@ int run(int argc, char* argv[])
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"y_gs_ms_os: "
<<
y_gs_ms_os
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"lse_gs_ms_os: "
<<
lse_gs_ms
.
mDesc
<<
std
::
endl
;
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
-
1
});
z_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
0
});
switch
(
init_method
)
switch
(
init_method
)
{
{
case
0
:
break
;
case
0
:
break
;
...
@@ -475,6 +475,7 @@ int run(int argc, char* argv[])
...
@@ -475,6 +475,7 @@ int run(int argc, char* argv[])
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
ygrad_device_buf
.
ToDevice
(
ygrad_gs_ms_os
.
mData
.
data
());
kgrad_device_buf
.
SetZero
();
kgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
//z_device_buf.SetZero();
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
...
@@ -545,6 +546,10 @@ int run(int argc, char* argv[])
...
@@ -545,6 +546,10 @@ int run(int argc, char* argv[])
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
//copy z matirx data form device
z_device_buf
.
FromDevice
(
z_g_m_n
.
mData
.
data
());
//std::cout << "z_g_m_n ref:\n" << z_g_m_n;
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
kgrad_device_buf
.
SetZero
();
// reset global accum buffer and rerun
vgrad_device_buf
.
SetZero
();
vgrad_device_buf
.
SetZero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
cb914a54
...
@@ -1441,7 +1441,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1441,7 +1441,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// z vgpr copy to global
// z vgpr copy to global
//
//
// z matrix threadwise desc
// z matrix threadwise desc
if
(
get_thread_global_1d_id
()
==
0
)
/*
if(get_thread_global_1d_id() == 0)
{
{
printf("m0: %d n0: %d m1: %d n1: %d m2: %d n2: %d n3: %d n4: %d \n",
printf("m0: %d n0: %d m1: %d n1: %d m2: %d n2: %d n3: %d n4: %d \n",
m0.value, // MRepeat
m0.value, // MRepeat
...
@@ -1452,7 +1452,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1452,7 +1452,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n2.value, // NGroupNum
n2.value, // NGroupNum
n3.value, // NInputNum
n3.value, // NInputNum
n4.value);
n4.value);
}
}
*/
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
...
@@ -1470,6 +1470,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1470,6 +1470,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
true
>
z_tenor_buffer
;
z_tenor_buffer
;
z_tenor_buffer
.
Clear
();
// z matrix global desc
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
...
@@ -1482,7 +1483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1482,7 +1483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
if
(
get_thread_global_1d_id
()
==
191
)
/*
if(get_thread_global_1d_id() == 191)
{
{
printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
wave_id[I0],
wave_id[I0],
...
@@ -1490,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1490,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
wave_id[I2],
wave_id[I2],
wave_m_n_id[I0],
wave_m_n_id[I0],
wave_m_n_id[I1]);
wave_m_n_id[I1]);
}
}
*/
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ushort
,
ushort
,
...
@@ -2116,6 +2117,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -2116,6 +2117,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
kgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
if
(
get_thread_global_1d_id
()
==
1
)
printf
(
"gemm1_k_block_outer_index: %d num_gemm1_k_block_outer_loop: %d
\n
"
,
gemm1_k_block_outer_index
,
num_gemm1_k_block_outer_loop
);
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
// shuffle dQ and write
// shuffle dQ and write
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment