Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cf9ef868
Commit
cf9ef868
authored
Jun 21, 2023
by
ltqin
Browse files
remove useless code
parent
118742b6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
219 deletions
+38
-219
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v6.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v6.hpp
+12
-48
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v7.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v7.hpp
+12
-48
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
...dwise_batched_multihead_attention_bacckward_ydotygrad.hpp
+14
-123
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v6.hpp
View file @
cf9ef868
...
...
@@ -34,8 +34,7 @@ template <typename GridwiseGemm,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
ORSGridDescriptor_M
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
bool
Deterministic
>
typename
ComputeBasePtrOfStridedBatch
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
...
...
@@ -49,7 +48,6 @@ __global__ void
const
ORSGridDescriptor_M
ors_grid_desc_m
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -65,22 +63,6 @@ __global__ void
const
long_index_t
ors_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
(
p_y_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_ors_grid
+
ors_batch_offset
,
p_shared
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
ors_grid_desc_m
,
block_2_ctile_map
,
i
);
}
}
else
{
// GridwiseGemm::test();
GridwiseGemm
::
Run
(
p_y_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
...
...
@@ -88,9 +70,8 @@ __global__ void
p_shared
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
ors_grid_desc_m
,
block_2_ctile_map
,
0
);
}
block_2_ctile_map
);
#else
ignore
=
p_y_grid
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -771,27 +752,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
// TODO: distinguish A/B
// datatype
GemmDataType
,
GemmAccDataType
,
ORSDataType
,
YGridDesc_M_O
,
ORSGridDesc_M
,
BlockSize
,
128
,
128
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
32
,
32
,
1
,
4
,
ABlockLdsExtraM
,
BBlockLdsExtraN
,
Deterministic
>
;
32
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -1061,8 +1027,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
ORSGridDesc_M
,
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
Deterministic
>
;
ComputeBasePtrOfStridedBatch
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -1077,7 +1042,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
ors_grid_desc_m_
,
arg
.
ors_block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
ors_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
y_grid_desc_m_o_
),
arg
.
compute_base_ptr_of_batch_
);
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v7.hpp
View file @
cf9ef868
...
...
@@ -33,8 +33,7 @@ template <typename GridwiseGemm,
typename
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
typename
ORSGridDescriptor_M
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
bool
Deterministic
>
typename
ComputeBasePtrOfStridedBatch
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
...
...
@@ -48,7 +47,6 @@ __global__ void
const
ORSGridDescriptor_M
ors_grid_desc_m
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
nblock
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -64,22 +62,6 @@ __global__ void
const
long_index_t
ors_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
if
constexpr
(
Deterministic
)
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
(
p_y_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_ors_grid
+
ors_batch_offset
,
p_shared
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
ors_grid_desc_m
,
block_2_ctile_map
,
i
);
}
}
else
{
// GridwiseGemm::test();
GridwiseGemm
::
Run
(
p_y_grid
+
c_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
...
...
@@ -87,9 +69,8 @@ __global__ void
p_shared
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
ors_grid_desc_m
,
block_2_ctile_map
,
0
);
}
block_2_ctile_map
);
#else
ignore
=
p_y_grid
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -787,27 +768,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using
GridwiseYDotYGrad
=
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
<
InputDataType
,
// TODO: distinguish A/B
// datatype
GemmDataType
,
GemmAccDataType
,
ORSDataType
,
YGridDesc_M_O
,
ORSGridDesc_M
,
BlockSize
,
256
,
128
,
KPerBlock
,
32
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
64
,
64
,
1
,
4
,
ABlockLdsExtraM
,
BBlockLdsExtraN
,
Deterministic
>
;
64
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -1076,8 +1042,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
DeviceOp
::
ORSGridDesc_M
,
typename
GridwiseYDotYGrad
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
Deterministic
>
;
ComputeBasePtrOfStridedBatch
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -1092,7 +1057,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
ors_grid_desc_m_
,
arg
.
ors_block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
ors_block_2_ctile_map_
.
CalculateGridSize
(
arg
.
y_grid_desc_m_o_
),
arg
.
compute_base_ptr_of_batch_
);
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_bacckward_ydotygrad.hpp
View file @
cf9ef868
...
...
@@ -21,27 +21,12 @@
namespace
ck
{
template
<
typename
InputDataType
,
typename
GemmDataType
,
typename
FloatGemmAcc
,
typename
FloatORS
,
typename
CGridDesc_M_N
,
typename
ORSGridDesc_M
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
Gemm1NPerBlock
,
index_t
Gemm1KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
B1K1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
index_t
ABlockLdsExtraM
,
index_t
BBlockLdsExtraN
,
bool
Deterministic
>
index_t
NPerBlock
>
struct
GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -56,44 +41,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static
constexpr
auto
I9
=
Number
<
9
>
{};
static
constexpr
auto
WaveSize
=
64
;
// K1 should be Number<...>
// Gemm0
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
// Gemm1
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
...
...
@@ -110,12 +65,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
Gemm1
NPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
y_grid_desc_mblock_mperblock_oblock_operblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
Gemm1
NPerBlock
>
{}))),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
...
...
@@ -141,7 +96,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
Gemm1
NPerBlock
,
CGridDesc_M_N
>
(
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
...
...
@@ -151,64 +106,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
// S / dP Gemm (type 1 rcr)
struct
Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
static
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
MXdlPerWave
,
MWaves
,
MPerXdl
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// Blockwise gemm with transposed XDL output
using
BlockwiseGemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
GemmDataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
(
b_block_desc_bk0_n_bk1
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
,
true
>
;
// TransposeC
static
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
static
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
};
template
<
index_t
BlockSize_
,
index_t
BlockSliceLength_M_
,
index_t
BlockSliceLength_O_
>
struct
YDotYGrad_M_O_
{
...
...
@@ -223,18 +120,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
static_assert
(
ThreadClusterLength_M
*
ThreadSliceLength_M
==
BlockSliceLength_M_
,
""
);
using
SrcBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
Float
GemmAcc
,
Float
ORS
,
ThreadSliceLength_M
*
ThreadSliceLength_O
,
true
>
;
using
DstBufType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
Float
GemmAcc
,
ThreadSliceLength_M
,
true
>
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
Float
ORS
,
ThreadSliceLength_M
,
true
>
;
};
using
YDotYGrad_M_O
=
YDotYGrad_M_O_
<
BlockSize
,
MPerBlock
,
Gemm1
NPerBlock
>
;
using
YDotYGrad_M_O
=
YDotYGrad_M_O_
<
BlockSize
,
MPerBlock
,
NPerBlock
>
;
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
return
MPerBlock
*
sizeof
(
Float
GemmAcc
);
return
MPerBlock
*
sizeof
(
Float
ORS
);
}
__device__
static
void
test
()
{}
...
...
@@ -246,8 +143,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
const
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
&
y_grid_desc_mblock_mperblock_oblock_operblock
,
const
ORSGridDesc_M
&
ors_grid_desc_m
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
index_t
block_idx_m
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
y_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_grid
,
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetElementSpaceSize
());
...
...
@@ -269,7 +165,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
return
;
}
const
index_t
block_work_idx_m
=
Deterministic
?
block_idx_m
:
block_work_idx
[
I0
];
const
index_t
block_work_idx_m
=
block_work_idx
[
I0
];
constexpr
auto
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
));
...
...
@@ -306,7 +202,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
// performs double duty for both y and ygrad
auto
yygrad_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
InputDataType
,
Float
GemmAcc
,
Float
ORS
,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
,
decltype
(
y_thread_desc_m0_m1_o0_o1
),
decltype
(
y_thread_desc_m0_m1_o0_o1
.
GetLengths
()),
...
...
@@ -321,13 +217,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
auto
y_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
ygrad_thread_buf
=
typename
YDotYGrad_M_O
::
SrcBufType
{};
auto
y_dot_ygrad_thread_accum_buf
=
typename
YDotYGrad_M_O
::
DstBufType
{};
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatGemmAcc
*>
(
p_shared
),
MPerBlock
);
if
constexpr
(
Deterministic
)
{
block_sync_lds
();
}
auto
y_dot_ygrad_block_accum_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatORS
*>
(
p_shared
),
MPerBlock
);
// clear accum buffers
y_dot_ygrad_thread_accum_buf
.
Clear
();
...
...
@@ -366,7 +257,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_YDotYGrad
MakeORSGridDescriptor_MBlock_MRepeat_NWave_MPerXdl
(
ors_grid_desc_m
);
auto
ors_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
Float
GemmAcc
,
Float
ORS
,
FloatORS
,
decltype
(
ors_thread_desc_mblock_mrepeat_mwave_mperxdl
),
decltype
(
ors_grid_desc_mblock_mrepeat_mwave_mperxdl
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment