Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
27482328
Commit
27482328
authored
Jan 29, 2023
by
fsx950223
Browse files
format code
parent
24f96b50
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
121 additions
and
97 deletions
+121
-97
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
...oftmax_gemm/grouped_multihead_attention_backward_fp16.cpp
+93
-70
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
+28
-27
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_fp16.cpp
View file @
27482328
...
...
@@ -393,7 +393,8 @@ int run(int argc, char* argv[])
std
::
vector
<
DeviceMemPtr
>
vgrad_tensors_device
;
std
::
size_t
group_count
=
3
;
std
::
size_t
flop
=
0
,
num_byte
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
){
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
4
+
1
);
int
N
=
128
*
(
rand
()
%
4
+
1
);
int
K
=
64
;
...
...
@@ -424,8 +425,8 @@ int run(int argc, char* argv[])
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// Y layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// Y layout [G0, G1, M, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
//
pass
Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
...
...
@@ -463,7 +464,8 @@ int run(int argc, char* argv[])
Tensor
<
DataType
>
y_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
DataType
>
ygrad_gs_ms_os
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
);
Tensor
<
LSEDataType
>
lse_gs_ms
(
lse_gs_ms_lengths
,
lse_gs_ms_strides
);
if
(
i
<
4
){
if
(
i
<
4
)
{
std
::
cout
<<
"q_gs_ms_ks: "
<<
q_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"k_gs_ns_ks: "
<<
k_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"v_gs_os_ns: "
<<
v_gs_os_ns
.
mDesc
<<
std
::
endl
;
...
...
@@ -539,19 +541,24 @@ int run(int argc, char* argv[])
Tensor
<
DataType
>
y_g_m_o
({
BatchCount
,
M
,
O
});
Tensor
<
LSEDataType
>
lse_g_m
({
BatchCount
,
M
});
q_gs_ms_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
q_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
q_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
k_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
k_g_n_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
v_gs_os_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
v_g_n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
])
=
self
(
idx
);
});
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
);
run_attention_fwd_host
(
q_g_m_k
,
k_g_n_k
,
v_g_n_o
,
alpha
,
s_g_m_n
,
p_g_m_n
,
y_g_m_o
,
lse_g_m
);
y_gs_ms_os
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
y_gs_ms_os
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
y_g_m_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
]);
});
lse_gs_ms
.
ForEach
(
[
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_g_m
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
]);
});
...
...
@@ -567,15 +574,24 @@ int run(int argc, char* argv[])
y_tensors
.
push_back
(
y_gs_ms_os
);
lse_tensors
.
push_back
(
lse_gs_ms
);
ygrad_tensors
.
push_back
(
ygrad_gs_ms_os
);
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
v_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
y_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
lse_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
k_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
v_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
y_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
lse_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
LSEDataType
)
*
lse_gs_ms
.
GetElementSpaceSize
()));
qgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
q_gs_ms_ks
.
GetElementSpaceSize
()));
kgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
k_gs_ns_ks
.
GetElementSpaceSize
()));
vgrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
v_gs_os_ns
.
GetElementSpaceSize
()));
ygrad_tensors_device
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
DataType
)
*
y_gs_ms_os
.
GetElementSpaceSize
()));
q_tensors_device
.
back
()
->
ToDevice
(
q_gs_ms_ks
.
data
());
k_tensors_device
.
back
()
->
ToDevice
(
k_gs_ns_ks
.
data
());
v_tensors_device
.
back
()
->
ToDevice
(
v_gs_os_ns
.
data
());
...
...
@@ -595,8 +611,7 @@ int run(int argc, char* argv[])
p_ygrad
.
push_back
(
ygrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
p_qgrad
.
push_back
(
qgrad_tensors_device
.
back
()
->
GetDeviceBuffer
());
}
auto
argument
=
gemm
.
MakeArgument
(
p_q
,
auto
argument
=
gemm
.
MakeArgument
(
p_q
,
p_k
,
p_v
,
p_y
,
...
...
@@ -644,13 +659,15 @@ int run(int argc, char* argv[])
bool
pass
=
true
;
if
(
do_verification
)
{
for
(
int
i
=
0
;
i
<
group_count
;
i
++
){
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
qgrad_tensors_device
[
i
]
->
SetZero
();
kgrad_tensors_device
[
i
]
->
SetZero
();
vgrad_tensors_device
[
i
]
->
SetZero
();
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
){
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
G0
=
v_tensors
[
i
].
GetLengths
()[
0
];
int
G1
=
v_tensors
[
i
].
GetLengths
()[
1
];
...
...
@@ -695,13 +712,19 @@ int run(int argc, char* argv[])
ref_gemm_grad_invoker
.
Run
(
RefGemmGradArg
{
sgrad_g_n_m
,
q_g_m_ks
[
i
],
kgrad_g_n_k
,
PassThrough
{},
PassThrough
{},
Scale
{
alpha
}});
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
qgrad_gs_ms_ks_host_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
kgrad_gs_ns_ks_host_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
vgrad_gs_os_ns_host_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
qgrad_gs_ms_ks_device_result
(
q_tensors
[
i
].
GetLengths
(),
q_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
kgrad_gs_ns_ks_device_result
(
k_tensors
[
i
].
GetLengths
(),
k_tensors
[
i
].
GetStrides
());
Tensor
<
DataType
>
vgrad_gs_os_ns_device_result
(
v_tensors
[
i
].
GetLengths
(),
v_tensors
[
i
].
GetStrides
());
qgrad_tensors_device
[
i
]
->
FromDevice
(
qgrad_gs_ms_ks_device_result
.
data
());
kgrad_tensors_device
[
i
]
->
FromDevice
(
kgrad_gs_ns_ks_device_result
.
data
());
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
27482328
...
...
@@ -36,7 +36,7 @@ template <typename GridwiseGemm,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
...
...
@@ -739,8 +739,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
auto
vgrad_grid_desc_n_o
=
DeviceOp
::
MakeVGradGridDescriptor_N_O
(
problem_desc
.
b1_gs_gemm1ns_gemm1ks_lengths
,
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
ygrad_grid_desc_o0_m_o1
=
DeviceOp
::
MakeYGradGridDescriptor_O0_M_O1
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
const
auto
ygrad_grid_desc_o0_m_o1
=
DeviceOp
::
MakeYGradGridDescriptor_O0_M_O1
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
const
auto
a_grid_desc_g_m_k
=
Transform
::
MakeAGridDescriptor_G_M_K
(
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
...
...
@@ -889,8 +889,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2
<
GridwiseGemm
,
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2
<
GridwiseGemm
,
GroupKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
...
...
@@ -963,7 +963,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
*
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
);
if
(
!
(
c_g
==
device_arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment