Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
35b2971e
Commit
35b2971e
authored
Jul 26, 2023
by
danyao12
Browse files
fix bugs and optimize bwd qloop 2 kernels
parent
52478ac3
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
628 additions
and
541 deletions
+628
-541
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
..._softmax_gemm/batched_multihead_attention_backward_v3.cpp
+36
-43
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+38
-44
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+122
-119
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+123
-118
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+0
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+66
-51
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+71
-50
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+81
-53
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+84
-56
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
...gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
+4
-3
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v3.cpp
View file @
35b2971e
This diff is collapsed.
Click to expand it.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
35b2971e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
35b2971e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
35b2971e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
35b2971e
...
@@ -73,7 +73,7 @@ __global__ void
...
@@ -73,7 +73,7 @@ __global__ void
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1
,
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -140,7 +140,7 @@ __global__ void
...
@@ -140,7 +140,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
@@ -175,7 +175,7 @@ __global__ void
...
@@ -175,7 +175,7 @@ __global__ void
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
lse_grid_desc_m
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
35b2971e
...
@@ -1040,7 +1040,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1040,7 +1040,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
return
ave_time
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
35b2971e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -36,7 +36,8 @@ __global__ void
...
@@ -36,7 +36,8 @@ __global__ void
kernel_grouped_multihead_attention_backward_ydotygrad_v1
(
kernel_grouped_multihead_attention_backward_ydotygrad_v1
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
)
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
cast_pointer_to_generic_address_space
(
group_kernel_args
));
cast_pointer_to_generic_address_space
(
group_kernel_args
));
...
@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
...
@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1
(
kernel_grouped_multihead_attention_backward_
qloop_
xdl_cshuffle_
light_
v1
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
...
@@ -106,7 +108,8 @@ __global__ void
...
@@ -106,7 +108,8 @@ __global__ void
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
...
@@ -158,7 +161,7 @@ __global__ void
...
@@ -158,7 +161,7 @@ __global__ void
{
{
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
...
@@ -180,7 +183,6 @@ __global__ void
...
@@ -180,7 +183,6 @@ __global__ void
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
@@ -194,7 +196,7 @@ __global__ void
...
@@ -194,7 +196,7 @@ __global__ void
}
}
else
else
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
...
@@ -216,7 +218,6 @@ __global__ void
...
@@ -216,7 +218,6 @@ __global__ void
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
@@ -307,7 +308,7 @@ template <index_t NumDimG,
...
@@ -307,7 +308,7 @@ template <index_t NumDimG,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
struct
DeviceGroupedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V1
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
...
@@ -320,7 +321,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -320,7 +321,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// TODO: implement bias combination
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
using
DeviceOp
=
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
;
using
DeviceOp
=
DeviceGroupedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V1
;
struct
ProblemDesc
struct
ProblemDesc
{
{
std
::
vector
<
index_t
>
a_gs_ms_ks_lengths
;
std
::
vector
<
index_t
>
a_gs_ms_ks_lengths
;
...
@@ -341,9 +342,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -341,9 +342,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
d_gs_ms_lengths
;
std
::
vector
<
index_t
>
d_gs_ms_strides
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_strides
;
...
@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
GemmSpec
==
GemmSpecialization
::
MNPadding
||
make_tuple
(
Sequence
<
0
>
{}),
GemmSpec
==
GemmSpecialization
::
MKPadding
||
make_tuple
(
Sequence
<
0
>
{}));
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
else
{
// not pad M
return
d_grid_desc_mraw
;
}
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V1
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
YGridDesc_M_O
,
LSEGridDesc_M
,
LSEGridDesc_M
,
LSEGridDesc_M
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -725,14 +734,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -725,14 +734,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
// GridwiseYDotYGrad
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
// TODO: distinguish A/B
DDataType
,
DDataType
,
// datatype
DYGridDesc_M_O
,
Y
GridDesc_M
_O
,
D
GridDesc_M
,
DGridDesc_M
,
BlockSize
,
Block
Size
,
DMPer
Block
,
DM
PerBlock
,
DK
PerBlock
,
DK
PerBlock
>
;
Gemm1N
PerBlock
>
;
using
DBlock2CTileMap
=
using
DBlock2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
...
@@ -776,7 +785,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -776,7 +785,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameter
// D parameter
DDataType
*
p_d_grid_
;
DDataType
*
p_d_grid_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
DGridDesc_M
d_grid_desc_m_
;
DGridDesc_M
d_grid_desc_m_
;
DBlock2CTileMap
d_block_2_ctile_map_
;
DBlock2CTileMap
d_block_2_ctile_map_
;
typename
GridwiseYDotYGrad
::
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseYDotYGrad
::
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -956,7 +964,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -956,7 +964,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// D parameters
// D parameters
const
auto
p_d_grid
=
static_cast
<
DDataType
*>
(
p_Ds
[
i
]);
const
auto
p_d_grid
=
static_cast
<
DDataType
*>
(
p_Ds
[
i
]);
const
auto
d_grid_desc_m
=
const
auto
d_grid_desc_m
=
DeviceOp
::
MakeDGridDescriptor_M
(
problem_desc
.
d
_gs_ms_lengths
[
NumDimG
]);
DeviceOp
::
MakeDGridDescriptor_M
(
problem_desc
.
lse
_gs_ms_lengths
[
NumDimG
]);
const
auto
d_y_grid_desc_m_o
=
DTransform
::
MakeCGridDescriptor_M_N
(
const
auto
d_y_grid_desc_m_o
=
DTransform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
...
@@ -1001,7 +1009,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1001,7 +1009,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
raw_m_padded
,
raw_m_padded
,
raw_n_padded
,
raw_n_padded
,
p_d_grid
,
p_d_grid
,
d_y_grid_desc_m_o
,
d_grid_desc_m
,
d_grid_desc_m
,
d_block_2_ctile_map
,
d_block_2_ctile_map
,
d_y_grid_desc_mblock_mperblock_nblock_nperblock
,
d_y_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -1105,17 +1112,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1105,17 +1112,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
ave_time
=
launch_kernel
();
ave_time
=
launch_kernel
();
}
}
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1
<
const
auto
kernel
=
GridwiseGemm
,
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v1
<
GroupKernelArg
,
GridwiseGemm
,
AElementwiseOperation
,
GroupKernelArg
,
BElementwiseOperation
,
AElementwiseOperation
,
AccElementwiseOperation
,
BElementwiseOperation
,
B1ElementwiseOperation
,
AccElementwiseOperation
,
CElementwiseOperation
,
B1ElementwiseOperation
,
has_main_k_block_loop_
,
CElementwiseOperation
,
Deterministic
>
;
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
stream_config
,
stream_config
,
...
@@ -1139,11 +1148,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1139,11 +1148,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// to concern Gemm0's loop
// to concern Gemm0's loop
if
(
all_has_main_k_block_loop
)
if
(
all_has_main_k_block_loop
)
{
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
if
(
!
some_has_main_k_block_loop
)
else
if
(
!
some_has_main_k_block_loop
)
{
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
else
{
{
...
@@ -1169,22 +1188,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1169,22 +1188,18 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
for
(
index_t
i
=
0
;
i
<
arg
.
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
arg
.
group_count_
;
i
++
)
{
{
// TODO: Check if tensor specialization & strides mismatch
// TODO: Check if tensor specialization & strides mismatch
const
auto
&
kernel_arg
=
arg
.
group_kernel_args_
[
i
];
const
auto
&
kernel_arg
=
arg
.
group_kernel_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
kernel_arg
.
d_y_grid_desc_m_o_
,
kernel_arg
.
d_block_2_ctile_map_
))
{
return
false
;
}
// Check if C permute dimension matches GEMM + GEMM shape
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
...
@@ -1352,7 +1367,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1352,7 +1367,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1"
str
<<
"DeviceGroupedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V1"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
35b2971e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -36,7 +36,8 @@ __global__ void
...
@@ -36,7 +36,8 @@ __global__ void
kernel_grouped_multihead_attention_backward_ydotygrad_v2
(
kernel_grouped_multihead_attention_backward_ydotygrad_v2
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
)
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
cast_pointer_to_generic_address_space
(
group_kernel_args
));
cast_pointer_to_generic_address_space
(
group_kernel_args
));
...
@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
...
@@ -89,12 +90,13 @@ template <typename GridwiseGemm,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
#endif
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2
(
kernel_grouped_multihead_attention_backward_
qloop_
xdl_cshuffle_
light_
v2
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
...
@@ -106,7 +108,8 @@ __global__ void
...
@@ -106,7 +108,8 @@ __global__ void
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
const
auto
arg_ptr
=
reinterpret_cast
<
const
GroupKernelArg
*>
(
...
@@ -158,7 +161,7 @@ __global__ void
...
@@ -158,7 +161,7 @@ __global__ void
{
{
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
...
@@ -180,7 +183,6 @@ __global__ void
...
@@ -180,7 +183,6 @@ __global__ void
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
@@ -194,7 +196,7 @@ __global__ void
...
@@ -194,7 +196,7 @@ __global__ void
}
}
else
else
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
...
@@ -216,7 +218,6 @@ __global__ void
...
@@ -216,7 +218,6 @@ __global__ void
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
block_2_ctile_map_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
arg_ptr
[
group_id
].
c0_matrix_mask_
,
...
@@ -314,7 +315,7 @@ template <index_t NumDimG,
...
@@ -314,7 +315,7 @@ template <index_t NumDimG,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
struct
DeviceGroupedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V2
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
:
public
BaseOperator
// TODO inherit atten bwd op once API stablizes
{
{
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
...
@@ -327,7 +328,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -327,7 +328,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: implement bias combination
// TODO: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
using
DeviceOp
=
DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
;
using
DeviceOp
=
DeviceGroupedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V2
;
struct
ProblemDesc
struct
ProblemDesc
{
{
std
::
vector
<
index_t
>
a_gs_ms_ks_lengths
;
std
::
vector
<
index_t
>
a_gs_ms_ks_lengths
;
...
@@ -348,9 +349,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -348,9 +349,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
index_t
>
d_gs_ms_lengths
;
std
::
vector
<
index_t
>
d_gs_ms_strides
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>>
acc0_biases_gs_ms_ns_strides
;
...
@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -564,10 +562,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
DMPerBlock
)
*
DMPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
MPad
=
M
-
MRaw
;
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
GemmSpec
==
GemmSpecialization
::
MNPadding
||
make_tuple
(
Sequence
<
0
>
{}),
GemmSpec
==
GemmSpecialization
::
MKPadding
||
make_tuple
(
Sequence
<
0
>
{}));
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M
return
transform_tensor_descriptor
(
d_grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
else
{
// not pad M
return
d_grid_desc_mraw
;
}
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -658,7 +668,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V2
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -680,7 +690,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
YGridDesc_M_O
,
YGridDesc_M_O
,
LSEGridDesc_M
,
LSEGridDesc_M
,
LSEGridDesc_M
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -733,14 +742,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -733,14 +742,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
// GridwiseYDotYGrad
// GridwiseYDotYGrad
using
GridwiseYDotYGrad
=
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
// TODO: distinguish A/B
DDataType
,
DDataType
,
// datatype
DYGridDesc_M_O
,
Y
GridDesc_M
_O
,
D
GridDesc_M
,
DGridDesc_M
,
BlockSize
,
Block
Size
,
DMPer
Block
,
DM
PerBlock
,
DK
PerBlock
,
DK
PerBlock
>
;
Gemm1N
PerBlock
>
;
using
DBlock2CTileMap
=
using
DBlock2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
OffsettedBlockToCTileMap
<
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
>
;
...
@@ -784,7 +793,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -784,7 +793,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameter
// D parameter
DDataType
*
p_d_grid_
;
DDataType
*
p_d_grid_
;
DYGridDesc_M_O
d_y_grid_desc_m_o_
;
DGridDesc_M
d_grid_desc_m_
;
DGridDesc_M
d_grid_desc_m_
;
DBlock2CTileMap
d_block_2_ctile_map_
;
DBlock2CTileMap
d_block_2_ctile_map_
;
typename
GridwiseYDotYGrad
::
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseYDotYGrad
::
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -870,6 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -870,6 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
index_t
z_random_matrix_offset
=
0
;
index_t
z_random_matrix_offset
=
0
;
d_grid_size_
=
0
;
d_grid_size_
=
0
;
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
for
(
index_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
...
@@ -920,6 +929,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -920,6 +929,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const
index_t
BlockStart
=
grid_size_
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
grid_size_grp
=
const
index_t
grid_size_grp
=
(
Deterministic
?
1
:
block_2_ctile_map
.
CalculateGridSize
(
k_grid_desc_n_k
))
*
(
Deterministic
?
1
:
block_2_ctile_map
.
CalculateGridSize
(
k_grid_desc_n_k
))
*
...
@@ -959,7 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -959,7 +972,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// D parameters
// D parameters
const
auto
p_d_grid
=
static_cast
<
DDataType
*>
(
p_Ds
[
i
]);
const
auto
p_d_grid
=
static_cast
<
DDataType
*>
(
p_Ds
[
i
]);
const
auto
d_grid_desc_m
=
const
auto
d_grid_desc_m
=
DeviceOp
::
MakeDGridDescriptor_M
(
problem_desc
.
d
_gs_ms_lengths
[
NumDimG
]);
DeviceOp
::
MakeDGridDescriptor_M
(
problem_desc
.
lse
_gs_ms_lengths
[
NumDimG
]);
const
auto
d_y_grid_desc_m_o
=
DTransform
::
MakeCGridDescriptor_M_N
(
const
auto
d_y_grid_desc_m_o
=
DTransform
::
MakeCGridDescriptor_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
...
@@ -1004,7 +1017,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1004,7 +1017,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
raw_m_padded
,
raw_m_padded
,
raw_n_padded
,
raw_n_padded
,
p_d_grid
,
p_d_grid
,
d_y_grid_desc_m_o
,
d_grid_desc_m
,
d_grid_desc_m
,
d_block_2_ctile_map
,
d_block_2_ctile_map
,
d_y_grid_desc_mblock_mperblock_nblock_nperblock
,
d_y_grid_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -1107,17 +1119,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1107,17 +1119,19 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
ave_time
=
launch_kernel
();
ave_time
=
launch_kernel
();
}
}
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2
<
const
auto
kernel
=
GridwiseGemm
,
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_light_v2
<
GroupKernelArg
,
GridwiseGemm
,
AElementwiseOperation
,
GroupKernelArg
,
BElementwiseOperation
,
AElementwiseOperation
,
AccElementwiseOperation
,
BElementwiseOperation
,
B1ElementwiseOperation
,
AccElementwiseOperation
,
CElementwiseOperation
,
B1ElementwiseOperation
,
has_main_k_block_loop_
,
CElementwiseOperation
,
Deterministic
>
;
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
stream_config
,
stream_config
,
...
@@ -1141,11 +1155,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1141,11 +1155,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// to concern Gemm0's loop
// to concern Gemm0's loop
if
(
all_has_main_k_block_loop
)
if
(
all_has_main_k_block_loop
)
{
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
if
(
!
some_has_main_k_block_loop
)
else
if
(
!
some_has_main_k_block_loop
)
{
{
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
+=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
else
{
{
...
@@ -1171,7 +1195,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1171,7 +1195,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
@@ -1181,11 +1207,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1181,11 +1207,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// TODO: Check if tensor specialization & strides mismatch
// TODO: Check if tensor specialization & strides mismatch
const
auto
&
kernel_arg
=
arg
.
group_kernel_args_
[
i
];
const
auto
&
kernel_arg
=
arg
.
group_kernel_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
if
(
!
GridwiseYDotYGrad
::
CheckValidity
(
kernel_arg
.
d_y_grid_desc_m_o_
,
kernel_arg
.
d_block_2_ctile_map_
))
{
return
false
;
}
// Check if C permute dimension matches GEMM + GEMM shape
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_g
=
device_arg
.
c_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
const
index_t
c_m
=
kernel_arg
.
y_grid_desc_m_o_
.
GetLength
(
I0
);
...
@@ -1358,7 +1379,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1358,7 +1379,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2"
str
<<
"DeviceGroupedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V2"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
35b2971e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -41,7 +41,6 @@ template <typename InputDataType,
...
@@ -41,7 +41,6 @@ template <typename InputDataType,
typename
VGridDesc_O0_N_O1
,
typename
VGridDesc_O0_N_O1
,
typename
YGridDesc_M_O
,
typename
YGridDesc_M_O
,
typename
LSEGridDesc_M
,
typename
LSEGridDesc_M
,
typename
DGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -83,7 +82,7 @@ template <typename InputDataType,
...
@@ -83,7 +82,7 @@ template <typename InputDataType,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
bool
Deterministic
,
bool
Deterministic
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
struct
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V1
{
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -1155,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1155,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
YGradGridDesc_O0_M_O1
>
typename
YGradGridDesc_O0_M_O1
>
...
@@ -1180,7 +1180,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1180,7 +1180,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
VGridDesc_O0_N_O1
&
v_grid_desc_o0_n_o1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
DGridDesc_M
&
d_grid_desc_m
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
const
YGradGridDesc_O0_M_O1
&
ygrad_grid_desc_o0_m_o1
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
...
@@ -1206,7 +1205,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1206,7 +1205,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_lse_grid
,
lse_grid_desc_m
.
GetElementSpaceSize
());
p_lse_grid
,
lse_grid_desc_m
.
GetElementSpaceSize
());
const
auto
d_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
d_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d_grid
,
d
_grid_desc_m
.
GetElementSpaceSize
());
p_d_grid
,
lse
_grid_desc_m
.
GetElementSpaceSize
());
// reuse lse grid descriptor
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
p_ygrad_grid
,
ygrad_grid_desc_o0_m_o1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -1532,6 +1531,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1532,6 +1531,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I6
])};
acc0_thread_origin
[
I6
])};
auto
d_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatD
,
FloatGemmAcc
,
decltype
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
),
decltype
(
lse_thread_desc_mb_m0_m1_m2_m3_m4
),
Sequence
<
1
,
m0
,
m1
,
m2
,
m3
,
m4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
num_gemm0_m_block_outer_loop
-
1
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
],
// mperxdl
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I6
])};
//
//
// z vgpr copy to global
// z vgpr copy to global
//
//
...
@@ -1651,11 +1669,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1651,11 +1669,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
// load d and lse
// load d and lse
//
//
lse
_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
d
_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
d_grid_buf
,
d_grid_buf
,
lse_thread_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_desc_mb_m0_m1_m2_m3_m4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
y_dot_ygrad_thread_buf
);
y_dot_ygrad_thread_buf
);
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_grid_buf
,
lse_grid_buf
,
...
@@ -1743,56 +1761,64 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1743,56 +1761,64 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
constexpr
auto
position_offset
=
M3
*
M4
;
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
// save z to global
if
(
p_z_grid
)
if
constexpr
(
IsDropout
)
{
{
if
(
p_z_grid
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_local
=
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_local
=
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
n_global
;
// unique element global 1d id
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
auto
global_elem_id
=
n_global
;
// unique element global 1d id
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
blockwise_dropout
true
>(
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
true
>(
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_thread_copy_vgpr_to_global
.
Run
(
z_grid_buf
);
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
}
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
else
z_tenor_buffer
,
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
ignore
=
z_grid_buf
;
z_grid_buf
);
}
else
{
ignore
=
z_grid_buf
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_local
=
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_local
=
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
n_global
;
// unique element global 1d id
auto
global_elem_id
=
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
decltype
(
position_offset
),
true
>(
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
// dS = P * (dP - Y_dot_dY)
...
@@ -1965,6 +1991,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
...
@@ -1965,6 +1991,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
d_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
}
while
(
0
<
gemm0_m_block_outer_index
--
);
// end j loop
}
while
(
0
<
gemm0_m_block_outer_index
--
);
// end j loop
// shuffle dK&dV and write
// shuffle dK&dV and write
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
35b2971e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -41,7 +41,6 @@ template <typename InputDataType,
...
@@ -41,7 +41,6 @@ template <typename InputDataType,
typename
VGridDesc_N0_O_N1
,
typename
VGridDesc_N0_O_N1
,
typename
YGridDesc_M_O
,
typename
YGridDesc_M_O
,
typename
LSEGridDesc_M
,
typename
LSEGridDesc_M
,
typename
DGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -91,7 +90,7 @@ template <typename InputDataType,
...
@@ -91,7 +90,7 @@ template <typename InputDataType,
bool
MaskOutUpperTriangle
,
bool
MaskOutUpperTriangle
,
bool
Deterministic
,
bool
Deterministic
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
struct
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_Light_V2
{
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
"Non-default loop scheduler is currently not supported"
);
"Non-default loop scheduler is currently not supported"
);
...
@@ -1110,6 +1109,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1110,6 +1109,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
YGradGridDesc_M0_O_M1
>
typename
YGradGridDesc_M0_O_M1
>
...
@@ -1135,7 +1135,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1135,7 +1135,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
const
VGridDesc_N0_O_N1
&
v_grid_desc_n0_o_n1
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
LSEGridDesc_M
&
lse_grid_desc_m
,
const
DGridDesc_M
&
d_grid_desc_m
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
YGradGridDesc_M0_O_M1
&
ygrad_grid_desc_m0_o_m1
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
...
@@ -1161,7 +1160,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1161,7 +1160,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_lse_grid
,
lse_grid_desc_m
.
GetElementSpaceSize
());
p_lse_grid
,
lse_grid_desc_m
.
GetElementSpaceSize
());
const
auto
d_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
d_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d_grid
,
d
_grid_desc_m
.
GetElementSpaceSize
());
p_d_grid
,
lse
_grid_desc_m
.
GetElementSpaceSize
());
// reuse lse grid descriptor
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -1516,6 +1515,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1516,6 +1515,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I6
])};
acc0_thread_origin
[
I6
])};
auto
d_thread_copy_global_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
FloatD
,
FloatGemmAcc
,
decltype
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
),
decltype
(
lse_thread_desc_mb_m0_m1_m2_m3_m4
),
Sequence
<
1
,
m0
,
m1
,
m2
,
m3
,
m4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
,
1
,
true
/* ResetCoordAfterRun */
>
{
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
num_gemm0_m_block_outer_loop
-
1
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
],
// mperxdl
acc0_thread_origin
[
I5
],
acc0_thread_origin
[
I6
])};
//
//
// z vgpr copy to global
// z vgpr copy to global
//
//
...
@@ -1612,11 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1612,11 +1630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
// load d and lse
// load d and lse
//
//
lse
_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
d
_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
d_grid_buf
,
d_grid_buf
,
lse_thread_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_desc_mb_m0_m1_m2_m3_m4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
y_dot_ygrad_thread_buf
);
y_dot_ygrad_thread_buf
);
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_copy_global_to_vgpr
.
Run
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_grid_buf
,
lse_grid_buf
,
...
@@ -1706,55 +1724,63 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -1706,55 +1724,63 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
constexpr
auto
position_offset
=
M3
*
M4
;
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
// save z to global
if
(
p_z_grid
)
if
constexpr
(
IsDropout
)
{
{
if
(
p_z_grid
)
{
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_local
=
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_local
=
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
n_global
;
// unique element global 1d id
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
auto
global_elem_id
=
n_global
;
// unique element global 1d id
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
auto
global_elem_id
=
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
blockwise_dropout
true
>(
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
decltype
(
z_tenor_buffer
),
decltype
(
position_offset
),
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
true
>(
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_thread_copy_vgpr_to_global
.
Run
(
z_grid_buf
);
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
}
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
else
z_tenor_buffer
,
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
ignore
=
z_grid_buf
;
z_grid_buf
);
}
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
else
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
{
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
ignore
=
z_grid_buf
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
n_global
;
// unique element global 1d id
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
position_offset
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
// gemm dV
// gemm dV
...
@@ -2005,6 +2031,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
...
@@ -2005,6 +2031,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step M
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
lse_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
d_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
lse_grid_desc_mb_m0_m1_m2_m3_m4
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
View file @
35b2971e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -26,7 +26,8 @@ template <typename InputDataType,
...
@@ -26,7 +26,8 @@ template <typename InputDataType,
typename
DGridDesc_M
,
typename
DGridDesc_M
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
>
index_t
NPerBlock
,
index_t
NPadded
>
struct
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
struct
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -103,7 +104,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
...
@@ -103,7 +104,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
MakeDefaultBlock2CTileMap
(
const
YGridDesc_M_N
&
y_grid_desc_m_n
)
MakeDefaultBlock2CTileMap
(
const
YGridDesc_M_N
&
y_grid_desc_m_n
)
{
{
// should rewrite BlockToCTileMap_M00_N0_M01Adapt
// should rewrite BlockToCTileMap_M00_N0_M01Adapt
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
1024
,
YGridDesc_M_N
>
(
y_grid_desc_m_n
);
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPadded
,
YGridDesc_M_N
>
(
y_grid_desc_m_n
);
}
}
using
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
YGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment