Commit d4256471 authored Aug 22, 2023 by letaoqin
parent 763e26be

    single block

Showing 3 changed files with 134 additions and 120 deletions

example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp (+37 -34)
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp (+2 -5)
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp (+95 -81)
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...
@@ -273,12 +273,12 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;
-    ck::index_t N  = 512;
+    ck::index_t M  = 64;
+    ck::index_t N  = 128;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
-    ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G0 = 1;
+    ck::index_t G1 = 1;
     bool input_permute  = false;
     bool output_permute = false;
...
@@ -395,7 +395,7 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<ZDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+    Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
     Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
...
@@ -404,6 +404,7 @@ int run(int argc, char* argv[])
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
+    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
     std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
...
@@ -414,13 +415,12 @@ int run(int argc, char* argv[])
     {
     case 0: break;
     case 1:
-        // q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
-        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{0});
+        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
        break;
    case 2:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
...
@@ -469,7 +469,7 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,
-                                                                                 // g1, m, o]
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); //dy[g0,g1, m, o]
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
        // assume mnko = 256
        // P = softmax(QK) = 0.0039 * ones
...
@@ -481,26 +481,6 @@ int run(int argc, char* argv[])
        //   = 0
    }

-    Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
-    Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
-    Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
-    Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
-    Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
-    Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
-    Tensor<LSEDataType> lse_g_m({BatchCount, M});
-
-    q_gs_ms_ks.ForEach(
-        [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-    k_gs_ns_ks.ForEach(
-        [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-    v_gs_os_ns.ForEach(
-        [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
-    d_gs_ms_ns.ForEach(
-        [&](auto& self, auto idx) { d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-
    // qkv gradients have the same descriptor as with qkv
    DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
...
@@ -517,7 +497,6 @@ int run(int argc, char* argv[])
    q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
    k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
-    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
    v_device_buf.ToDevice(v_gs_os_ns.mData.data());
    ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
...
@@ -630,15 +609,39 @@ int run(int argc, char* argv[])
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

-    // copy z matirx data form device
-    z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
-    z_gs_ms_ns.ForEach(
-        [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
    // std::cout << "z_g_m_n ref:\n" << z_g_m_n;

    bool pass = true;
    if(do_verification)
    {
+        // copy z matirx data form device
+        Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
+        Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
+        Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
+        Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
+        Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
+        Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
+        Tensor<LSEDataType> lse_g_m({BatchCount, M});
+
+        z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
+        z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
+            q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
+            k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        v_gs_os_ns.ForEach([&](auto& self, auto idx) {
+            v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        });
+        d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+
        // run fwd again for y, cause z_g_m_n update
        run_attention_fwd_host(q_g_m_k,
                               k_g_n_k,
...
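Note: the host-side change above moves the flat `*_g_m_k` reference tensors and their copy loops inside the `if(do_verification)` block, so they are only built when verification runs. All of the copy loops use the same convention: a 4-d `[G0, G1, ...]` tensor is flattened to a 3-d `[G, ...]` tensor with `g = idx[0] * G1 + idx[1]` (and `v_gs_os_ns` additionally swaps its last two axes from `[O, N]` to `[N, O]`). A minimal standalone sketch of that index mapping, with hypothetical sizes and plain vectors standing in for the repo's `Tensor` class:

    #include <cassert>
    #include <vector>

    int main()
    {
        // Hypothetical sizes standing in for G0, G1, M, K in the example.
        const int G0 = 2, G1 = 3, M = 4, K = 5;

        // q_gs_ms_ks: logically [G0, G1, M, K], stored row-major in a flat vector.
        std::vector<float> q_gs_ms_ks(G0 * G1 * M * K, 1.0f);

        // q_g_m_k: logically [G, M, K] with G = G0 * G1, as the verification code builds.
        std::vector<float> q_g_m_k(G0 * G1 * M * K, 0.0f);

        for(int g0 = 0; g0 < G0; ++g0)
            for(int g1 = 0; g1 < G1; ++g1)
                for(int m = 0; m < M; ++m)
                    for(int k = 0; k < K; ++k)
                    {
                        // Same mapping as the ForEach lambdas: g = idx[0] * G1 + idx[1].
                        const int g = g0 * G1 + g1;
                        q_g_m_k[(g * M + m) * K + k] =
                            q_gs_ms_ks[((g0 * G1 + g1) * M + m) * K + k];
                    }

        // For row-major storage this flattening is a pure relabeling of indices.
        assert(q_g_m_k == q_gs_ms_ks);
        return 0;
    }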
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...
@@ -118,7 +118,7 @@ __global__ void
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

-    const D0DataType* tmp_p_d0_grid = nullptr;
+    const D0DataType* tmp_p_d0_grid = p_d0_grid;
    if constexpr(!is_same<D0DataType, void>::value)
    {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
...
@@ -850,10 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
          p_drop_{p_drop}
    {
        // TODO: implement bias addition
-        ignore = p_acc0_biases;
        ignore = p_acc1_biases;
-        ignore = acc0_biases_gs_ms_ns_lengths;
-        ignore = acc0_biases_gs_ms_ns_strides;
        ignore = acc1_biases_gs_ms_gemm1ns_lengths;
        ignore = acc1_biases_gs_ms_gemm1ns_strides;
...
@@ -1042,7 +1039,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                has_main_k_block_loop_,
                is_dropout_,
                Deterministic>;
-
+            std::cout << "device address : " << arg.p_d0_grid_ << std::endl;
            return launch_and_time_kernel(
                stream_config,
                kernel,
...
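Note: the kernel-side change initializes `tmp_p_d0_grid` from `p_d0_grid` instead of `nullptr` before the `if constexpr` guard; when `D0DataType` is `void` the guarded block is discarded at compile time, so the pointer is only adjusted and dereferenced when a bias tensor actually exists. A minimal sketch of that compile-time guard pattern, with illustrative names that are not the repo's:

    #include <cstdio>
    #include <type_traits>

    // Illustrative stand-in for the kernel's bias-pointer setup.
    template <typename D0DataType>
    void set_bias_pointer(const D0DataType* p_d0_grid, long d0_batch_offset)
    {
        const D0DataType* tmp = p_d0_grid; // was nullptr before this commit
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            // Only instantiated for real bias types; discarded for void.
            tmp = p_d0_grid + d0_batch_offset;
            std::printf("bias element: %f\n", static_cast<double>(tmp[0]));
        }
        (void)tmp;
    }

    int main()
    {
        float bias[4] = {0.5f, 1.5f, 2.5f, 3.5f};
        set_bias_pointer(bias, 1);          // prints 1.5
        set_bias_pointer<void>(nullptr, 0); // guarded block compiles away
        return 0;
    }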
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...
@@ -123,9 +123,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    static constexpr auto B1K1 = Number<B1K1Value>{};

    // D0
-    static constexpr auto D0M3 = Number<2>{};
-    static constexpr auto D0M2 = Number<MPerXdl / D0M3.value>{};
-    static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
+    static constexpr auto D0M1 = Number<4>{};
+    static constexpr auto D0M0 = Number<MPerBlock / D0M1.value>{};
+    // static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};

    static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
    static constexpr auto DropoutNThread = mfma.num_input_blks; // 2
...
@@ -1160,18 +1160,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    __host__ __device__ static constexpr auto
    MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(const D0GridDesc_M_N& d0_grid_desc_m_n)
    {
-        const auto M      = d0_grid_desc_m_n.GetLength(I0);
-        const auto N      = d0_grid_desc_m_n.GetLength(I1);
-        const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / NPerBlock;
+        // const auto M      = d0_grid_desc_m_n.GetLength(I0);
+        // const auto N      = d0_grid_desc_m_n.GetLength(I1);
+        // const auto MBlock = M / MPerBlock;
+        // const auto NBlock = N / NPerBlock;

        const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 = transform_tensor_descriptor(
            d0_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, D0M1, D0M2, D0M3)),
-                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
+            make_tuple(make_unmerge_transform(make_tuple(D0M0, D0M1)),
+                       make_unmerge_transform(make_tuple(Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 2, 3, 5>{}, Sequence<1, 4>{}));
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        return d0_grid_desc_m0_n0_m1_m2_n1_m3;
    }
...
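Note: the descriptor rewrite above collapses the old four-factor M unmerge `(MBlock, D0M1, D0M2, D0M3)` into the two-factor `(D0M0, D0M1)`, so the D0 tile is addressed as `m = m0 * D0M1 + m1` within a single block's range. A small standalone sketch of what an unmerge transform does to a linear index, assuming an illustrative `MPerBlock = 64`; this shows only the index math, not CK's `transform_tensor_descriptor` machinery:

    #include <cassert>

    int main()
    {
        // Unmerge splits one dimension into factors; here M = D0M0 * D0M1,
        // mirroring make_unmerge_transform(make_tuple(D0M0, D0M1)).
        const int D0M1 = 4;         // fastest-varying factor, as in the new code
        const int D0M0 = 64 / D0M1; // MPerBlock / D0M1 with illustrative MPerBlock = 64

        for(int m = 0; m < D0M0 * D0M1; ++m)
        {
            const int m0 = m / D0M1; // coarse coordinate
            const int m1 = m % D0M1; // fine coordinate
            // The unmerged coordinates recombine to the original index.
            assert(m0 * D0M1 + m1 == m);
        }
        return 0;
    }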
@@ -1184,28 +1184,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
    {
        // B1 matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
-            make_tuple(I1, I1, I1, D0M2, Number<NPerBlock>{}, D0M3),
-            make_tuple(Number<NPerBlock>{} * D0M3,
-                       Number<NPerBlock>{} * D0M3,
-                       Number<NPerBlock>{} * D0M3,
-                       Number<NPerBlock>{} * D0M3,
-                       D0M3,
-                       I1));
+            make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
+            make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));
    }

    __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
    {
        constexpr auto d0_raw_m0_n_m1 =
-            make_naive_tensor_descriptor(make_tuple(D0M2, Number<NPerBlock>{}, D0M3),
-                                         make_tuple(Number<NPerBlock>{} * D0M3, D0M3, I1));
+            make_naive_tensor_descriptor(make_tuple(D0M0, Number<NPerBlock>{}, D0M1),
+                                         make_tuple(Number<NPerBlock>{} * D0M1, D0M1, I1));

        constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
            d0_raw_m0_n_m1,
-            make_tuple(make_unmerge_transform(make_tuple((D0M2 * D0M3) / I8, I2, I4 / D0M3)),
+            make_tuple(make_unmerge_transform(make_tuple(D0M0 / I2, I2)),
                       make_unmerge_transform(
                           make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
-                       make_pass_through_transform(D0M3)),
+                       make_pass_through_transform(D0M1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<2, 3, 4>{}, Sequence<0, 1>{}, Sequence<5>{}));
+            make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));

        return d0_n0_n1_m0_m1_m2_m3;
    }

    static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
...
@@ -1214,29 +1208,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();

    static constexpr auto d0_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4 / D0M3, D0M3));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I8, I1, D0M1));

    using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
        ThisThreadBlock,
        tensor_operation::element_wise::PassThrough,
        tensor_operation::element_wise::PassThrough,
        InMemoryDataOperationEnum::Set,
-        Sequence<1, 1, 1, D0M2, NPerBlock, D0M3>,   // BlockSliceLengths
-        Sequence<1, 1, 1, 8, 32, 1>,                // ThreadClusterLengths
-        Sequence<0, 1, 2, 3, 5, 4>,                 // ThreadClusterArrangeOrder
+        Sequence<D0M0, NPerBlock, D0M1>,            // BlockSliceLengths
+        Sequence<8, 32, 1>,                         // ThreadClusterLengths
+        Sequence<0, 2, 1>,                          // ThreadClusterArrangeOrder
        D0DataType,                                 // SrcData
        D0DataType,                                 // DstData
        D0GridDescriptor_M0_N0_M1_M2_N1_M3,         // SrcDesc
        decltype(d0_block_desc_m0_n0_m1_m2_n1_m3),  // DstDesc
-        Sequence<0, 1, 2, 3, 5, 4>,                 // SrcDimAccessOrder
-        Sequence<0, 1, 2, 4, 3, 5>,                 // DstDimAccessOrder
-        4,                                          // SrcVectorDim
+        Sequence<0, 2, 1>,                          // SrcDimAccessOrder
+        Sequence<1, 0, 2>,                          // DstDimAccessOrder
+        1,                                          // SrcVectorDim
        2,                                          // DstVectorDim
-        NPerBlock / 32,                             // SrcScalarPerVector
-        D0M3.value / 1,                             // DstScalarPerVector
+        1,                                          // SrcScalarPerVector
+        1,                                          // DstScalarPerVector
        1,
        1,
-        false,
+        true,
        true,                                       // DstResetCoord
        1>;
...
@@ -1245,11 +1239,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        D0DataType,                                 // DstData
        decltype(d0_block_desc_n0_n1_m0_m1_m2_m3),  // SrcDesc
        decltype(d0_thread_desc_),                  // DstDesc
-        Sequence<1, 1, 4, 1, 2, 2>,                 // SliceLengths
-        Sequence<0, 1, 2, 3, 4, 5>,                 // DimAccessOrder
-        5,                                          // SrcVectorDim
-        D0M3.value,                                 // SrcScalarPerVector
-        D0M3.value>;
+        Sequence<1, 1, 8, 1, 4>,                    // SliceLengths
+        Sequence<0, 1, 2, 3, 4>,                    // DimAccessOrder
+        4,                                          // SrcVectorDim
+        1,                                          // SrcScalarPerVector
+        1>;
    };

    template <bool HasMainKBlockLoop,
...
@@ -1739,13 +1733,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        // D0
        auto d0_block_copy_global_to_lds =
            typename D0Loader::D0BlockwiseCopy(d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                                               make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
+                                               make_multi_index(0, 0, 0),
                                               tensor_operation::element_wise::PassThrough{},
                                               D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
-                                               make_multi_index(0, 0, 0, 0, 0, 0),
+                                               make_multi_index(0, 0, 0),
                                               tensor_operation::element_wise::PassThrough{});

        auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
-            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0, 0));
+            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
        ignore = d0_thread_copy_lds_to_vgpr;
        //
        // set up Y dot dY
...
@@ -1922,6 +1916,22 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                s_slash_p_thread_buf,
                num_k_block_main_loop);

+            // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+            // {
+            //     auto p_lds = static_cast<D0DataType*>(p_shared);
+            //     for(int idx = 0; idx < 32; idx++)
+            //     {
+            //         float tmp_gbl = ck::type_convert<float>(p_d0_grid[idx]);
+            //         float tmp_lds = ck::type_convert<float>(p_lds[idx]);
+            //         block_sync_lds();
+            //         printf("p_d0_grid: %p, index: %d, gbl[idx]: %f, lds[idx]: %f \n",
+            //                ck::type_convert<const void*>(p_d0_grid),
+            //                idx,
+            //                tmp_gbl,
+            //                tmp_lds);
+            //     }
+            // }
+
            // 8d thread_desc in thread scope
            constexpr auto c_thread_lengths =
                s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
...
@@ -1983,61 +1993,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
-                static constexpr auto c_thread_desc_ =
-                    make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<16>{}));
+                // static constexpr auto c_thread_desc_ =
+                //     make_naive_tensor_descriptor_packed(make_tuple(I2, Number<16>{}));

                const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                    static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+                    static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
                    D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
                    D0Loader::d0_thread_desc_.GetElementSpaceSize());
                ignore = d0_thread_buf;

-                static_for<0, D0M1, 1>{}([&](auto mr) {
-                    // load data to lds
-                    d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                                                        d0_grid_buf);
-                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                        d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
-                    d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
-                                                         d0_block_buf);
-                    block_sync_lds();
-                    // read data form lds
-                    d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
-                                                   make_tuple(I0, I0, I0, I0, I0, I0),
-                                                   d0_block_buf,
-                                                   D0Loader::d0_thread_desc_,
-                                                   make_tuple(I0, I0, I0, I0, I0, I0),
-                                                   d0_thread_buf);
-                    // bias add
-                    static_for<0, D0Loader::d0_thread_desc_.GetElementSpaceSize(), 1>{}(
-                        [&](auto i) {
-                            constexpr index_t c_offset =
-                                c_thread_desc_.CalculateOffset(make_tuple(mr, i));
-                            s_slash_p_thread_buf(Number<c_offset>{}) +=
-                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
-                            // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
-                            // {
-                            //     printf("c_offset: %d s_slash_p_thread_buf(Number<c_offset>{}):
-                            //     %f, "
-                            //            "d0_thread_buf[i]: %f\n",
-                            //            c_offset,
-                            //            s_slash_p_thread_buf(Number<c_offset>{}),
-                            //            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
-                            // }
-                        });
-                });
-                d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                    d0_grid_desc_m0_n0_m1_m2_n1_m3,
-                    make_multi_index(1, 0, -D0M1.value, 0, 0, 0));
+                // static_for<0, 2, 1>{}([&](auto mr) {
+                // load data to lds
+                d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_buf);
+                // d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                //     d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
+                                                     d0_block_buf);
+                block_sync_lds();
+                // read data form lds
+                d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
+                                               make_tuple(I0, I0, I0, I0, I0),
+                                               d0_block_buf,
+                                               D0Loader::d0_thread_desc_,
+                                               make_tuple(I0, I0, I0, I0, I0),
+                                               d0_thread_buf);
+                // bias add
+                static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
+                    // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+                    // if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
+                    // {
+                    //     float tmp_lds =
+                    //         ck::type_convert<float>(static_cast<D0DataType*>(p_shared)[i.value]);
+                    //     float tmp_gbl = ck::type_convert<float>(p_d0_grid[i.value]);
+                    //     block_sync_lds();
+                    //     printf("id: %d : i: %d, gbl[i]: %f "
+                    //            ",lds[i]: %f d0_thread_buf[i]: %f\n",
+                    //            get_thread_local_1d_id(),
+                    //            i.value,
+                    //            tmp_gbl,
+                    //            tmp_lds,
+                    //            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
+                    // }
+                    s_slash_p_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                });
+                // });
+                // d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                //     d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                //     make_multi_index(1, 0, -D0M1.value, 0, 0, 0));
            }

            // P_i: = softmax(scalar * S_i:)
...
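Note: taken together, the gridwise changes replace the old `static_for<0, D0M1, 1>` repeat loop (read a slice, advance the source window, add into a sub-range of the accumulator) with one blockwise global-to-LDS copy plus one thread copy, after which every element of `s_slash_p_thread_buf` receives its bias in a single pass; this is presumably what the commit title "single block" refers to. A simplified host-side sketch of that control-flow change, with a hypothetical buffer size and plain arrays standing in for the CK copy objects:

    #include <cassert>
    #include <vector>

    int main()
    {
        const int n = 32; // hypothetical thread-buffer size

        std::vector<float> acc(n, 1.0f);  // stands in for s_slash_p_thread_buf
        std::vector<float> bias(n, 0.5f); // stands in for the D0 bias tile

        // Before: several passes, each loading a slice and adding into a sub-range.
        // After ("single block"): one load, then one pass over the whole buffer,
        // mirroring static_for<0, s_slash_p_thread_buf.Size(), 1>.
        std::vector<float> d0_thread_buf = bias; // one RunRead/RunWrite + thread copy
        for(int i = 0; i < n; ++i)
            acc[i] += d0_thread_buf[i];

        for(int i = 0; i < n; ++i)
            assert(acc[i] == 1.5f);
        return 0;
    }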