gaoqiong / composable_kernel

Commit dd6a8de4, authored Apr 06, 2022 by Jehandad Khan

    Merge branch 'develop' into jd/dev_pkg

Parents: 0aa899aa, abf4bdb9
Changes: 470
Showing 20 changed files with 2598 additions and 1597 deletions.
include/ck/tensor_operation/gpu/device/device_reduce_multiblock_atomic_add.hpp            +83  -66
include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp        +91  -71
include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp                       +85  -65
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp                            +7   -1
include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp                     +16  -16
include/ck/tensor_operation/gpu/device/tensor_layout.hpp                                  +16  -3
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp                        +34  -5
include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp                 +24  -0
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp                  +265 -304
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp      +72  -71
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_partial_reduce.hpp  +139 -166
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp                 +100 -90
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_xdlops_v2r3.hpp                +0   -649
include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp                  +11  -11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp                         +11  -11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp                         +11  -11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp                           +8   -8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp                           +49  -49
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp             +892 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp                    +684 -0
include/ck/tensor_operation/gpu/device/device_reduce_multiblock_atomic_add.hpp

@@ -17,8 +17,8 @@ namespace device {
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          int Rank,
-          typename ReduceDims,
+          index_t Rank,
+          index_t NumReduceDim,
           typename ReduceOperation,
           typename InElementwiseOperation,
           typename AccElementwiseOperation,

@@ -39,13 +39,18 @@ struct DeviceReduceMultiBlockAtomicAdd
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");

     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                    (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                       (MThreadSliceSize % OutDstVectorSize == 0),
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

     using IndexDataType = int32_t;

-    using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

     static constexpr bool support_AtomicAdd =
         std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

@@ -62,18 +67,18 @@ struct DeviceReduceMultiBlockAtomicAdd
                                     int blkGroupSize,
                                     int kBlockTileIterations)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
             {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));

                 return transform_tensor_descriptor(one_dim_inDesc,

@@ -84,7 +89,10 @@ struct DeviceReduceMultiBlockAtomicAdd
             }
             else
             {
-                const auto toReduceDimLengths =
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

@@ -92,23 +100,24 @@ struct DeviceReduceMultiBlockAtomicAdd
                 return transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(invariantDimLengths),
-                               make_merge_transform(toReduceDimLengths)),
+                               make_merge_transform(reduceDimLengths)),
                     make_tuple(InvariantDims{}, ReduceDims{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }();

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

         const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

-        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
-            in_grid_desc_m_k,
-            make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                       make_right_pad_transform(innerLen, inPad_K)),
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

@@ -118,24 +127,25 @@ struct DeviceReduceMultiBlockAtomicAdd
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

-        auto out_grid_desc_m_padded = transform_tensor_descriptor(
-            out_grid_desc_m,
-            make_tuple(make_right_pad_transform(outerLen, outPad)),
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, outPad)),
             make_tuple(Sequence<0>{}),
             make_tuple(Sequence<0>{}));
         return (out_grid_desc_m_padded);

@@ -143,43 +153,44 @@ struct DeviceReduceMultiBlockAtomicAdd
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
-            : in_dev_{in_dev}, out_dev_{out_dev}
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
+            : outLengths_{outLengths},
+              outStrides_{outStrides},
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              in_elementwise_op_{in_elementwise_op},
+              acc_elementwise_op_{acc_elementwise_op}
         {
             (void)out_indices_dev;
             (void)workspace_dev;

-            inLengths_  = inLengths;
-            inStrides_  = inStrides;
-            outLengths_ = outLengths;
-            outStrides_ = outStrides;
-            in_elementwise_op_  = in_elementwise_op;
-            acc_elementwise_op_ = acc_elementwise_op;
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_  = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_  = type_convert<AccDataType>(beta);

             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

-            reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];

             int iterations = 1;
             while(true)

@@ -212,7 +223,7 @@ struct DeviceReduceMultiBlockAtomicAdd
         std::vector<int> outStrides_;

         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

         const InDataType* in_dev_;
         OutDataType* out_dev_;

@@ -330,18 +341,22 @@ struct DeviceReduceMultiBlockAtomicAdd
         if constexpr(InSrcVectorDim == 0)
         {
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
+            {
                 return (false);
-
-            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                return (false);
+            }
+            else
+            {
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);

-            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                return (false);
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
+            };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);

             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)

@@ -367,23 +382,25 @@ struct DeviceReduceMultiBlockAtomicAdd
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           outLengths,
                                           outStrides,
                                           reduceDims,
                                           alpha,
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
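Note: the key interface change in this file is that the device op now takes the number of reduced dimensions (NumReduceDim) as a template parameter and receives the concrete reduceDims vector at Argument construction time, where shuffle_tensor_dimensions<Rank, NumReduceDim> reorders the input lengths/strides so that the reduced dimensions sit at the end. The sketch below is an illustration of the ordering the rest of the constructor relies on (a stand-alone approximation, not the library's actual implementation of shuffle_tensor_dimensions):

// Illustrative sketch only: the reordering that the updated Argument code assumes.
#include <algorithm>
#include <vector>

template <int Rank, int NumReduceDim>
std::vector<int> shuffle_tensor_dimensions_sketch(const std::vector<int>& origin,
                                                  const std::vector<int>& reduceDims)
{
    std::vector<int> shuffled;
    shuffled.reserve(Rank);

    // invariant dimensions keep their relative order and move to the front
    for(int dim = 0; dim < Rank; ++dim)
        if(std::find(reduceDims.begin(), reduceDims.end(), dim) == reduceDims.end())
            shuffled.push_back(origin[dim]);

    // reduce dimensions follow, so the innermost reduce dimension ends up at index Rank - 1
    for(int dim : reduceDims)
        shuffled.push_back(origin[dim]);

    return shuffled;
}

With this ordering, invariant_lowest_length can be read from shuffled[NumInvariantDim - 1] and reduce_lowest_length from shuffled[Rank - 1], which is exactly how the updated Argument constructor and IsSupportedArgument checks index inLengths_ and inStrides_.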
include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp

@@ -15,8 +15,8 @@ namespace device {
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          int Rank,
-          typename ReduceDims,
+          index_t Rank,
+          index_t NumReduceDim,
           typename ReduceOperation,
           typename InElementwiseOperation,
           typename AccElementwiseOperation,

@@ -37,26 +37,35 @@ struct DeviceReduceMultiBlockPartialReduce
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");

     static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                       (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

     static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");

     using IndexDataType = int32_t;

-    using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

-    size_t GetWorkspaceSizeInBytes(const std::vector<int>& inLengths) override
+    static constexpr int MaxBlockGroupSize = 256;
+
+    long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
+                                         const std::vector<int> reduceDims) override
     {
         size_t invariant_total_length;
         size_t reduce_total_length;

+        auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+
         std::tie(invariant_total_length, reduce_total_length) =
-            get_2d_lengths<Rank, ReduceDims>(inLengths);
+            get_2d_lengths<Rank, NumReduceDim>(inLengths_);

         int iterations = 1;
         while(true)

@@ -64,8 +73,7 @@ struct DeviceReduceMultiBlockPartialReduce
             int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                    (K_BlockTileSize * iterations);

-            // we want the blkGroupSize be not more than 128
-            if(testBlkGroupSize <= 128)
+            if(testBlkGroupSize <= MaxBlockGroupSize)
                 break;

             iterations++;

@@ -74,11 +82,12 @@ struct DeviceReduceMultiBlockPartialReduce
         int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                            (K_BlockTileSize * iterations);

-        size_t workspace_size = invariant_total_length * blkGroupSize;
+        long_index_t workspace_size = invariant_total_length * blkGroupSize;

-        size_t wsSizeInBytes =
-            !NeedIndices ? workspace_size * sizeof(AccDataType)
-                         : workspace_size * (sizeof(AccDataType) + sizeof(int)) + 64 + sizeof(int);
+        long_index_t wsSizeInBytes =
+            !NeedIndices
+                ? workspace_size * sizeof(AccDataType)
+                : workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);

         return (wsSizeInBytes);
     };

@@ -90,18 +99,18 @@ struct DeviceReduceMultiBlockPartialReduce
                                     int blkGroupSize,
                                     int kBlockTileIterations)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
             {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));

                 return transform_tensor_descriptor(one_dim_inDesc,

@@ -112,7 +121,10 @@ struct DeviceReduceMultiBlockPartialReduce
             }
             else
             {
-                const auto toReduceDimLengths =
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

@@ -120,38 +132,41 @@ struct DeviceReduceMultiBlockPartialReduce
                 return transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(invariantDimLengths),
-                               make_merge_transform(toReduceDimLengths)),
+                               make_merge_transform(reduceDimLengths)),
                     make_tuple(InvariantDims{}, ReduceDims{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }();

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

         const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = reduceSizePerBlock * blkGroupSize - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

-        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
-            in_grid_desc_m_k,
-            make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                       make_right_pad_transform(innerLen, inPad_K)),
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

         return (in_grid_desc_m_k_padded);
     };

-    static auto MakeWorkspace2dDescriptor(int outerLen, int blkGroupSize)
+    static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
     {
-        auto ws_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(outerLen, blkGroupSize));
+        auto ws_desc_m_k =
+            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

-        const auto wsPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto wsPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

         auto ws_desc_m_k_padded = transform_tensor_descriptor(
             ws_desc_m_k,
-            make_tuple(make_right_pad_transform(outerLen, wsPad),
+            make_tuple(make_right_pad_transform(invariantLength, wsPad),
                        make_pass_through_transform(blkGroupSize)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

@@ -161,43 +176,43 @@ struct DeviceReduceMultiBlockPartialReduce
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<index_t>& inLengths,
-                 const std::vector<index_t>& inStrides,
-                 const std::vector<index_t>& outLengths,
-                 const std::vector<index_t>& outStrides,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const AccElementwiseOperation& acc_elementwise_op)
-            : in_dev_{in_dev},
-              out_dev_{out_dev},
-              out_indices_dev_{out_indices_dev},
-              workspace_dev_{workspace_dev}
+                 const InElementwiseOperation in_elementwise_op,
+                 const AccElementwiseOperation acc_elementwise_op)
+            : outLengths_{outLengths},
+              outStrides_{outStrides},
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              out_indices_dev_{out_indices_dev},
+              workspace_dev_{workspace_dev},
+              in_elementwise_op_{in_elementwise_op},
+              acc_elementwise_op_{acc_elementwise_op}
         {
-            inLengths_  = inLengths;
-            inStrides_  = inStrides;
-            outLengths_ = outLengths;
-            outStrides_ = outStrides;
-            in_elementwise_op_  = in_elementwise_op;
-            acc_elementwise_op_ = acc_elementwise_op;
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_  = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_  = type_convert<AccDataType>(beta);

             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

-            reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];

             int iterations = 1;
             while(true)

@@ -205,8 +220,7 @@ struct DeviceReduceMultiBlockPartialReduce
                 int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                                        (K_BlockTileSize * iterations);

-                // we want the blkGroupSize be not more than 128
-                if(testBlkGroupSize <= 128)
+                if(testBlkGroupSize <= MaxBlockGroupSize)
                     break;

                 iterations++;

@@ -236,7 +250,7 @@ struct DeviceReduceMultiBlockPartialReduce
         std::vector<int> outStrides_;

         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

         const InDataType* in_dev_;
         OutDataType* out_dev_;

@@ -334,18 +348,22 @@ struct DeviceReduceMultiBlockPartialReduce
         if constexpr(InSrcVectorDim == 0)
         {
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
+            {
                 return (false);
-
-            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                return (false);
+            }
+            else
+            {
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);

-            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                return (false);
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
+            };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);

             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)

@@ -368,23 +386,25 @@ struct DeviceReduceMultiBlockPartialReduce
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const AccElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const AccElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           outLengths,
                                           outStrides,
                                           reduceDims,
                                           alpha,
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
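Note: this file now sizes the partial-reduce workspace with long_index_t instead of size_t and caps the block-group count with MaxBlockGroupSize = 256 instead of the hard-coded 128. To make that sizing concrete, here is a small stand-alone sketch with made-up numbers (none of the values below come from the commit); it mirrors the iteration/blkGroupSize search and the byte count for the non-indexed case:

// Worked example of the workspace sizing logic (host-only sketch, assumed values).
#include <cstdint>
#include <cstdio>

int main()
{
    const int64_t invariant_total_length = 1024;      // assumed
    const int64_t reduce_total_length    = 1000000;   // assumed
    const int K_BlockTileSize            = 256;       // assumed tile size
    const int MaxBlockGroupSize          = 256;       // new upper bound (was 128)

    // find the smallest iteration count whose block-group count fits the bound
    int iterations = 1;
    while(true)
    {
        int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                               (K_BlockTileSize * iterations);
        if(testBlkGroupSize <= MaxBlockGroupSize)
            break;
        iterations++;
    }

    const int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                             (K_BlockTileSize * iterations);

    // one partial accumulator per (invariant index, block group); the 64-bit type is
    // the reason the return type of GetWorkspaceSizeInBytes became long_index_t
    const int64_t workspace_size = invariant_total_length * blkGroupSize;
    const int64_t wsSizeInBytes  = workspace_size * sizeof(float); // AccDataType = float assumed

    std::printf("iterations=%d blkGroupSize=%d bytes=%lld\n",
                iterations, blkGroupSize, static_cast<long long>(wsSizeInBytes));
}

With these assumed inputs the loop settles at iterations = 16, blkGroupSize = 245, and roughly 1 MB of workspace; large reduce lengths are exactly where a 32-bit size could overflow, hence the switch to long_index_t.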
include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp

@@ -16,7 +16,7 @@ template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
           index_t Rank,
-          typename ReduceDims,
+          index_t NumReduceDim,
           typename ReduceOperation,
           typename InElementwiseOperation,
           typename OutElementwiseOperation,

@@ -36,15 +36,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
                   "Threadwise can only be called with KThreadClusterSize be 1 !");

     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                    (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                       (MThreadSliceSize % OutDstVectorSize == 0),
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

     using IndexDataType = int32_t;

     static constexpr bool BetaIsZero = NeedIndices;

-    using InvariantDims = decltype(get_invariant_dims<Rank, ReduceDims>());
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

-    static constexpr index_t srcDims = Rank;
-    static constexpr index_t dstDims = (InvariantDims::Size() == 0) ? 1 : InvariantDims::Size();
-    static constexpr bool reduceAllDims = (InvariantDims::Size() == 0);
+    static constexpr index_t numSrcDim = Rank;
+    static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

     static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

@@ -52,18 +57,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
                                     const std::vector<int>& inStrides)
     {
-        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<srcDims>{});
-        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<srcDims>{});
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

         const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

         const auto in_grid_desc_m_k = [&]() {
-            if constexpr(reduceAllDims)
+            if constexpr(reduceAllDim)
             {
                 const auto one_dim_inDesc = transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(tupleSrcLengths)),
-                    make_tuple(typename arithmetic_sequence_gen<0, srcDims, 1>::type{}),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                     make_tuple(Sequence<0>{}));

                 return transform_tensor_descriptor(one_dim_inDesc,

@@ -74,7 +79,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
             }
             else
             {
-                const auto toReduceDimLengths =
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                 const auto invariantDimLengths =
                     make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

@@ -82,22 +90,24 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
                 return transform_tensor_descriptor(
                     inDesc,
                     make_tuple(make_merge_transform(invariantDimLengths),
-                               make_merge_transform(toReduceDimLengths)),
+                               make_merge_transform(reduceDimLengths)),
                     make_tuple(InvariantDims{}, ReduceDims{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }();

-        const auto outerLen = in_grid_desc_m_k.GetLength(Number<0>{});
-        const auto innerLen = in_grid_desc_m_k.GetLength(Number<1>{});
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

-        const auto inPad_M = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
-        const auto inPad_K = math::integer_least_multiple(innerLen, K_BlockTileSize) - innerLen;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K =
+            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

-        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
-            in_grid_desc_m_k,
-            make_tuple(make_right_pad_transform(outerLen, inPad_M),
-                       make_right_pad_transform(innerLen, inPad_K)),
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

@@ -107,24 +117,25 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
                                     const std::vector<int>& outStrides)
     {
-        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<dstDims>{});
-        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<dstDims>{});
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});

         auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

         auto out_grid_desc_m = transform_tensor_descriptor(
             outDesc,
             make_tuple(make_merge_transform(tupleDstLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
             make_tuple(Sequence<0>{}));

-        const auto outerLen = out_grid_desc_m.GetLength(Number<0>{});
+        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

-        const auto outPad = math::integer_least_multiple(outerLen, M_BlockTileSize) - outerLen;
+        const auto outPad =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

-        auto out_grid_desc_m_padded = transform_tensor_descriptor(
-            out_grid_desc_m,
-            make_tuple(make_right_pad_transform(outerLen, outPad)),
+        auto out_grid_desc_m_padded = transform_tensor_descriptor(
+            out_grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, outPad)),
             make_tuple(Sequence<0>{}),
             make_tuple(Sequence<0>{}));
         return (out_grid_desc_m_padded);

@@ -132,42 +143,45 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int>& inLengths,
-                 const std::vector<int>& inStrides,
-                 const std::vector<int>& outLengths,
-                 const std::vector<int>& outStrides,
+        Argument(const std::vector<int> inLengths,
+                 const std::vector<int> inStrides,
+                 const std::vector<int> outLengths,
+                 const std::vector<int> outStrides,
+                 const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
                  OutDataType* out_dev,
                  IndexDataType* out_indices_dev,
                  AccDataType* workspace_dev,
-                 const InElementwiseOperation& in_elementwise_op,
-                 const OutElementwiseOperation& acc_elementwise_op)
-            : in_dev_{in_dev}, out_dev_{out_dev}, out_indices_dev_{out_indices_dev}
+                 const InElementwiseOperation in_elementwise_op,
+                 const OutElementwiseOperation acc_elementwise_op)
+            : outLengths_{outLengths},
+              outStrides_{outStrides},
+              in_dev_{in_dev},
+              out_dev_{out_dev},
+              out_indices_dev_{out_indices_dev},
+              in_elementwise_op_{in_elementwise_op},
+              acc_elementwise_op_{acc_elementwise_op}
         {
             (void)workspace_dev;

-            inLengths_  = inLengths;
-            inStrides_  = inStrides;
-            outLengths_ = outLengths;
-            outStrides_ = outStrides;
-            in_elementwise_op_  = in_elementwise_op;
-            acc_elementwise_op_ = acc_elementwise_op;
+            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
+            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

-            alpha_ = static_cast<AccDataType>(alpha);
-            beta_  = static_cast<OutDataType>(beta);
+            alpha_ = type_convert<AccDataType>(alpha);
+            beta_  = type_convert<AccDataType>(beta);

             std::tie(invariant_total_length, reduce_total_length) =
-                get_2d_lengths<Rank, ReduceDims>(inLengths);
+                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
                 invariant_lowest_length = 1;
             else
-                invariant_lowest_length = inLengths[InvariantDims::At(InvariantDims::Size() - 1)];
+                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

-            reduce_lowest_length = inLengths[ReduceDims::At(ReduceDims::Size() - 1)];
+            reduce_lowest_length = inLengths_[Rank - 1];

             gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize;

@@ -179,7 +193,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
         std::vector<int> outStrides_;

         AccDataType alpha_;
-        OutDataType beta_;
+        AccDataType beta_;

         const InDataType* in_dev_;
         OutDataType* out_dev_;

@@ -272,18 +286,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
         if constexpr(InSrcVectorDim == 0)
         {
-            if constexpr(InvariantDims::Size() == 0)
+            if constexpr(NumInvariantDim == 0)
+            {
                 return (false);
-
-            if(pArg->inStrides_[InvariantDims::At(InvariantDims::Size() - 1)] != 1)
-                return (false);
+            }
+            else
+            {
+                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
+                    return (false);

-            if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
-                return (false);
+                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
+                    return (false);
+            };
         }
         else
         {
-            if(pArg->inStrides_[ReduceDims::At(ReduceDims::Size() - 1)] != 1)
+            if(pArg->inStrides_[Rank - 1] != 1)
                 return (false);

             if(pArg->reduce_lowest_length % InSrcVectorSize != 0)

@@ -304,23 +322,25 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
     };

     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int>& inLengths,
-                        const std::vector<int>& inStrides,
-                        const std::vector<int>& outLengths,
-                        const std::vector<int>& outStrides,
+    MakeArgumentPointer(const std::vector<int> inLengths,
+                        const std::vector<int> inStrides,
+                        const std::vector<int> outLengths,
+                        const std::vector<int> outStrides,
+                        const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
                         void* out_dev,
                         void* out_indices_dev,
                         void* workspace_dev,
-                        const InElementwiseOperation& in_elementwise_op,
-                        const OutElementwiseOperation& acc_elementwise_op) override
+                        const InElementwiseOperation in_elementwise_op,
+                        const OutElementwiseOperation acc_elementwise_op) override
     {
         return std::make_unique<Argument>(inLengths,
                                           inStrides,
                                           outLengths,
                                           outStrides,
                                           reduceDims,
                                           alpha,
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp

@@ -5,10 +5,16 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-enum GemmSpecialization_t
+enum struct GemmSpecialization
 {
     Default,
     MPadding,
     NPadding,
     KPadding,
     MNPadding,
     MKPadding,
     NKPadding,
     MNKPadding,
 };

 } // namespace device
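Note: turning the plain enum GemmSpecialization_t into the scoped enum struct GemmSpecialization means the enumerators no longer leak into namespace device and no longer convert implicitly to int, so call sites have to qualify them. A minimal illustration follows; only the enum itself comes from this header, the DeviceGemmExample template is a hypothetical placeholder:

// Illustration of the scoped-enum usage; DeviceGemmExample is hypothetical.
enum struct GemmSpecialization
{
    Default, MPadding, NPadding, KPadding, MNPadding, MKPadding, NKPadding, MNKPadding,
};

template <GemmSpecialization GemmSpec>
struct DeviceGemmExample
{
    // whether the M dimension gets padded for this specialization
    static constexpr bool PadM =
        (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
         GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
};

// a bare "MNKPadding" no longer compiles; the enumerator must be fully qualified
using GemmWithFullPadding = DeviceGemmExample<GemmSpecialization::MNKPadding>;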
include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp

@@ -37,11 +37,11 @@ namespace ck {
 // The boolean member "indexable" are also provided in reduce_binary_operactor for
 // easier checking by the upper-layer codes in the kernels.

-template <typename T, ReduceTensorOp_t Op>
+template <typename T, ReduceTensorOp Op>
 struct reduce_binary_operator;

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
+struct reduce_binary_operator<T, ReduceTensorOp::ADD>
 {
     using opType   = reduce::Add<T>;
     using dataType = T;

@@ -50,7 +50,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
+struct reduce_binary_operator<T, ReduceTensorOp::MUL>
 {
     using opType   = reduce::Mul<T>;
     using dataType = T;

@@ -59,7 +59,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
+struct reduce_binary_operator<T, ReduceTensorOp::MIN>
 {
     using opType   = reduce::Min<T>;
     using dataType = T;

@@ -68,7 +68,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
+struct reduce_binary_operator<T, ReduceTensorOp::MAX>
 {
     using opType   = reduce::Max<T>;
     using dataType = T;

@@ -77,7 +77,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
+struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
 {
     using opType   = reduce::AMax<T>;
     using dataType = T;

@@ -86,7 +86,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
+struct reduce_binary_operator<T, ReduceTensorOp::AVG>
 {
     using opType   = reduce::Add<T>;
     using dataType = T;

@@ -95,7 +95,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
+struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
 {
     using opType   = reduce::Add<T>;
     using dataType = T;

@@ -104,7 +104,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
 };

 template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
+struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
 {
     using opType   = reduce::Add<T>;
     using dataType = T;

@@ -115,7 +115,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
 // The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
 // functor classes.
 // The two unary functors are called before and afer the Reduction is executed respectively

-template <typename T, ReduceTensorOp_t Op, bool IsFirstReduce, bool IsLastReduce>
+template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
 struct reduce_unary_operator
 {
     using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;

@@ -123,42 +123,42 @@ struct reduce_unary_operator
 };

 template <typename T, bool IsFirstReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, IsFirstReduce, true>
+struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
 {
     using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
 };

 template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, IsLastReduce>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
 {
     using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
 };

 template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, IsLastReduce>
+struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
 {
     using InElementwiseOperation  = tensor_operation::element_wise::UnaryAbs<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
 };

 template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
 {
     using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
 };

 template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
 {
     using InElementwiseOperation  = tensor_operation::element_wise::UnarySquare<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
 };

 template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
 {
     using InElementwiseOperation  = tensor_operation::element_wise::UnaryIdentic<T, T>;
     using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
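Note: these specializations are compile-time maps; given an element type and a ReduceTensorOp enumerator they name the binary accumulation functor plus the pre/post elementwise functors. A hedged usage sketch follows (the ReducePlan alias is illustrative and assumes the relevant CK headers are included; only the trait specializations themselves come from this file):

// Sketch of how the mapping traits are typically consumed.
#include <type_traits>

template <typename T, ck::ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
struct ReducePlan
{
    // binary op applied while accumulating (e.g. reduce::Add<T> for AVG/NORM1/NORM2)
    using BinaryOp = typename ck::reduce_binary_operator<T, Op>::opType;

    // unary ops applied to each loaded value and to the final accumulator
    using PreOp =
        typename ck::reduce_unary_operator<T, Op, IsFirstReduce, IsLastReduce>::InElementwiseOperation;
    using PostOp =
        typename ck::reduce_unary_operator<T, Op, IsFirstReduce, IsLastReduce>::AccElementwiseOperation;
};

// For NORM2 with IsFirstReduce = IsLastReduce = true this picks reduce::Add<float>,
// UnarySquare<float, float> and UnarySqrt<float, float>, i.e. sqrt(sum(x * x)).
static_assert(
    std::is_same<ReducePlan<float, ck::ReduceTensorOp::NORM2, true, true>::BinaryOp,
                 ck::reduce::Add<float>>::value,
    "NORM2 accumulates with Add");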
include/ck/tensor_operation/gpu/device/tensor_layout.hpp

-#ifndef TENSOR_LAYOUT_HPP
-#define TENSOR_LAYOUT_HPP
+#pragma once

 namespace ck {
 namespace tensor_layout {

@@ -85,6 +84,7 @@ struct NKHW : public BaseTensorLayout
     static constexpr const char* name = "NKHW";
 };

+// 3D Conv
 struct NDHWC : public BaseTensorLayout
 {
     static constexpr const char* name = "NDHWC";

@@ -99,6 +99,20 @@ struct NDHWK : public BaseTensorLayout
 {
     static constexpr const char* name = "NDHWK";
 };

+struct NCDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NCDHW";
+};
+
+struct KCZYX : public BaseTensorLayout
+{
+    static constexpr const char* name = "KCZYX";
+};
+
+struct NKDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NKDHW";
+};
+
 } // namespace convolution

@@ -113,4 +127,3 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
 } // namespace tensor_layout
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

-#ifndef CK_ELEMENT_WISE_OPERATION_HPP
-#define CK_ELEMENT_WISE_OPERATION_HPP
-
-#include "data_type.hpp"
+#pragma once
+
+#include "data_type.hpp"

 namespace ck {

@@ -19,6 +16,8 @@ struct PassThrough
     __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }

+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }
+
     __host__ __device__ void operator()(double& y, const double& x) const { y = x; }
 };

 struct Add

@@ -239,6 +238,24 @@ struct UnaryIdentic<int32_t, int32_t, false>
     __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; };
 };

+template <>
+struct UnaryIdentic<int32_t, int32_t, true>
+{
+    __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };
+
+    __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; };
+
+    int32_t divider_ = 1;
+};
+
+template <>
+struct UnaryIdentic<int8_t, int8_t, false>
+{
+    __host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; };
+
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; };
+};
+
 template <typename Y, typename X, bool HasDividing = false>
 struct UnarySquare;

@@ -311,6 +328,19 @@ struct UnaryAbs<double, double>
     __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
 };

+template <>
+struct UnaryAbs<int8_t, int8_t>
+{
+    __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
+
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const
+    {
+        int8_t sgn = x >> (8 - 1);
+
+        y = (x ^ sgn) - sgn;
+    };
+};
+
 template <typename Y, typename X>
 struct UnarySqrt;

@@ -333,4 +363,3 @@ struct UnarySqrt<double, double>
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
-#endif
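Note: the new UnaryAbs<int8_t, int8_t> specialization computes the absolute value without a branch. The arithmetic shift makes sgn equal to 0 for non-negative x and -1 (all bits set) for negative x, so (x ^ sgn) - sgn returns x unchanged or ~x + 1 = -x. A stand-alone host-side check of the same idiom (ordinary C++, not the CK functor itself):

// Host-only check of the branchless abs trick used by UnaryAbs<int8_t, int8_t>.
#include <cstdint>
#include <cstdio>

int8_t abs_branchless(int8_t x)
{
    int8_t sgn = x >> (8 - 1); // 0 for x >= 0, -1 (all ones) for x < 0
    return (x ^ sgn) - sgn;    // x unchanged, or ~x + 1 == -x
}

int main()
{
    // note: as with any two's-complement abs of this form, INT8_MIN maps to itself
    for(int v : {-7, -1, 0, 1, 42, 127})
        std::printf("%4d -> %4d\n", v, abs_branchless(static_cast<int8_t>(v)));
}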
include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp  (new file, 0 → 100644)

#pragma once

#include "data_type.hpp"

namespace ck {
namespace tensor_operation {
namespace element_wise {

struct ReduceSum
{
    __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }

    __host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
};

struct ReduceSquareSum
{
    __host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }

    __host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
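Note: the new functors expose just a neutral value (GetReduceZeroValue) plus an in-place accumulate step (Reduce), which is all a caller needs to fold a sequence of values. A minimal host-side illustration of that contract follows; the structs are copied from the new header (without the __host__ __device__ qualifiers so it builds as plain C++), and the fold driver is illustrative only:

#include <cstdio>

struct ReduceSum
{
    static constexpr float GetReduceZeroValue() { return float(0); }
    void Reduce(float& acc, float v) const { acc += v; }
};

struct ReduceSquareSum
{
    static constexpr float GetReduceZeroValue() { return float(0); }
    void Reduce(float& acc, float v) const { acc += v * v; }
};

// illustrative driver: start from the neutral value, accumulate element by element
template <typename ReduceOp>
float fold(const float* p, int n, ReduceOp op)
{
    float acc = ReduceOp::GetReduceZeroValue();
    for(int i = 0; i < n; ++i)
        op.Reduce(acc, p[i]);
    return acc;
}

int main()
{
    const float x[4] = {1.f, 2.f, 3.f, 4.f};
    std::printf("sum=%g sum_sq=%g\n", fold(x, 4, ReduceSum{}), fold(x, 4, ReduceSquareSum{}));
    // prints sum=10 sum_sq=30
}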
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp
View file @
dd6a8de4
...
...
@@ -31,8 +31,10 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
...
...
@@ -52,14 +54,16 @@ __global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const
OutElementwiseOperation
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
Out
DataType
beta
,
Acc
DataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
if
constexpr
(
!
NeedIndices
)
{
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
constexpr
bool
IsSecondCall
=
false
;
GridwiseReduction
::
template
Run
<
IsSecondCall
>(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
...
...
@@ -102,14 +106,16 @@ kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const
OutElementwiseOperation
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
Out
DataType
beta
,
Acc
DataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
if
constexpr
(
!
NeedIndices
)
{
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
constexpr
bool
IsSecondCall
=
true
;
GridwiseReduction
::
template
Run
<
IsSecondCall
>(
in_grid_desc_m_k
,
out_grid_desc_m
,
in_elementwise_op
,
acc_elementwise_op
,
...
...
@@ -156,64 +162,88 @@ template <typename InDataType,
index_t
OutDstVectorSize
>
struct
GridwiseReduction_mk_to_m_blockwise
{
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
(
MThreadSliceSize
%
OutDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
InSrcVectorDim
==
0
);
static
constexpr
auto
buffer_1d_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BlockSize
>
{}));
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
template
<
typename
T
>
using
PassThroughOp
=
tensor_operation
::
element_wise
::
UnaryIdentic
<
T
,
T
>
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
template
<
bool
IsSecondCall
>
__device__
static
void
Run
(
const
InGridDesc_M_K
&
in_grid_desc_m_k
,
const
OutGridDesc_M
&
out_grid_desc_m
,
const
InElementwiseOperation
&
in_elementwise_op
,
const
OutElementwiseOperation
&
acc_elementwise_op
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_global
,
Out
DataType
beta
,
Acc
DataType
beta
,
OutDataType
*
const
__restrict__
p_out_global
,
const
IndexDataType
*
const
__restrict__
p_ws_indices_global
,
IndexDataType
*
const
__restrict__
p_indices_global
)
{
using
BlockwiseReduce
=
PartitionedBlockwiseReductionOn1dBuffer
<
decltype
(
buffer_1d_desc
),
AccDataType
,
if
constexpr
(
IsSecondCall
)
{
static_assert
(
InSrcVectorDim
==
1
,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!"
);
};
using
BlockwiseReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
reorder_thread_cluster
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ReduceOperation
,
PropagateNan
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ReduceOperation
,
PropagateNan
>
;
using
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
(
void
)
p_ws_indices_global
;
(
void
)
p_indices_global
;
// LDS
__shared__
AccDataType
p_
block_
reduce_buffer
[
BlockSize
];
__shared__
AccDataType
p_reduce_
work_
buffer
[
BlockSize
];
const
auto
zeroVal
=
ReduceOperation
::
GetReductionZeroVal
();
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
_t
::
Global
>
(
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
type_convert
<
InDataType
>
(
zeroVal
));
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
_t
::
Global
>
(
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc_m
.
GetElementSpaceSize
());
auto
block_
reduce_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
_t
::
Lds
>
(
p_
block_
reduce_buffer
,
BlockSize
);
auto
reduce_
work_
buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_
work_
buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
_t
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
zeroVal
;
});
...
...
@@ -221,28 +251,28 @@ struct GridwiseReduction_mk_to_m_blockwise
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_1d_id
=
get_block_1d_id
();
const
index_t
thread_m_cluster_id
=
reorder_thread_cluster
?
thread_local_id
%
MThreadClusterSize
:
((
thread_local_id
/
KThreadClusterSize
)
%
MThreadClusterSize
);
const
index_t
thread_k_cluster_id
=
reorder_thread_cluster
?
((
thread_local_id
/
MThreadClusterSize
)
%
KT
hread
C
luster
Size
)
:
thread_local_id
%
KT
hread
C
luster
Size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
)
);
const
auto
thread_m_cluster_id
=
t
hread
_c
luster
_idx
[
I0
];
const
auto
thread_k_cluster_id
=
t
hread
_c
luster
_idx
[
I1
]
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
typename
conditional
<
InSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
block_global_1d_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
KThreadSliceSize
));
...
...
@@ -260,45 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise
                                    make_tuple(I0, I0),
                                    in_thread_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-               });
-
-               // reduce on each thread-local slice
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                   in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
                });
            });

+           ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

-       constexpr auto reduced_data_desc =
-           make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
-
-       static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-           if constexpr(reorder_thread_cluster)
-           {
-               block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
-                   accu_value_buf[I];
-           }
-           else
-               block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
-                   accu_value_buf[I];
-
-           accu_value_buf(I) = zeroVal;
-
-           __syncthreads();
-
-           BlockwiseReduce::Reduce(
-               block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
-       });
+       constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
+
+       static_for<0, MThreadSliceSize, 1>{}(
+           [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
...
...
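The ThreadwiseReduce::Reduce() call introduced above folds each thread's MThreadSliceSize x KThreadSliceSize slice into one accumulator per M entry. A hedged sketch of that pattern (plain C++, illustrative names and a generic reduce_op, not the CK ThreadwiseReduction type):

template <int MPerThread, int KPerThread, typename AccT, typename ReduceOp>
void threadwise_reduce_sketch(const AccT (&in_thread_buf)[MPerThread][KPerThread],
                              AccT (&accu_value_buf)[MPerThread],
                              ReduceOp reduce_op)
{
    // fold the K slice of every M entry into its running accumulator
    for(int iM = 0; iM < MPerThread; ++iM)
        for(int iK = 0; iK < KPerThread; ++iK)
            accu_value_buf[iM] = reduce_op(accu_value_buf[iM], in_thread_buf[iM][iK]);
}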
@@ -315,7 +326,7 @@ struct GridwiseReduction_mk_to_m_blockwise
        {
            if(!float_equal_zero{}(beta))
            {
-               StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+               StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
...
...
@@ -340,7 +351,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
+                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            };
        };
...
...
@@ -350,18 +361,18 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
-                                                  PassThroughOp<AccDataType>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
-                                                  InMemoryDataOperationEnum_t::Set,
+                                                  InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                out_grid_desc_m,
                make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
...
...
@@ -374,19 +385,17 @@ struct GridwiseReduction_mk_to_m_blockwise
                               const OutElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_global,
-                              OutDataType beta,
+                              AccDataType beta,
                               OutDataType* const __restrict__ p_out_global,
                               const IndexDataType* const __restrict__ p_ws_indices_global,
                               IndexDataType* const __restrict__ p_indices_global)
    {
-       using BlockwiseReduceWithIndex = PartitionedBlockwiseReductionWithIndexOn1dBuffer<
-           decltype(buffer_1d_desc), AccDataType, IndexDataType, BlockSize,
-           MThreadClusterSize, KThreadClusterSize, reorder_thread_cluster,
-           ReduceOperation, PropagateNan>;
+       using BlockwiseReduceWithIndex = PartitionedBlockwiseReductionWithIndex<
+           AccDataType, IndexDataType, BlockSize,
+           ThreadClusterLengths_M_K, ThreadClusterArrangeOrder,
+           ReduceOperation, PropagateNan>;
...
...
@@ -398,62 +407,61 @@ struct GridwiseReduction_mk_to_m_blockwise
        (void)p_ws_indices_global;

        // LDS
-       __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-       __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
+       __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+       __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

-       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
-       auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

-       auto block_reduce_val_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
-       auto block_reduce_idx_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+       auto reduce_work_val_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+       auto reduce_work_idx_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_idx_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
-       const index_t thread_m_cluster_id =
-           reorder_thread_cluster ? thread_local_id % MThreadClusterSize
-                                  : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
-       const index_t thread_k_cluster_id =
-           reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
-                                  : thread_local_id % KThreadClusterSize;
+       const auto thread_cluster_idx =
+           thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+       const auto thread_m_cluster_id = thread_cluster_idx[I0];
+       const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1, false>(
            in_grid_desc_m_k,
            make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));
...
...
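The BlockwiseReduceWithIndex/BlockwiseReduce calls in this file reduce per-thread partials across the thread block through LDS. A hedged sketch of that idea for the value-only case (HIP-style device code with illustrative names; assumes KClusters is a power of two; not CK's PartitionedBlockwiseReduction):

__device__ void partitioned_block_reduce_sketch(float* lds, // MClusters * KClusters elements
                                                float& partial,
                                                int m_cluster_id,
                                                int k_cluster_id,
                                                int KClusters)
{
    // park this thread's partial at its (m, k) cluster slot
    lds[m_cluster_id * KClusters + k_cluster_id] = partial;
    __syncthreads();

    // fold the K cluster dimension with a tree reduction
    for(int stride = KClusters / 2; stride > 0; stride /= 2)
    {
        if(k_cluster_id < stride)
            lds[m_cluster_id * KClusters + k_cluster_id] +=
                lds[m_cluster_id * KClusters + k_cluster_id + stride];
        __syncthreads();
    }

    // every thread of the row can read back the reduced value
    partial = lds[m_cluster_id * KClusters];
}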
@@ -479,56 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise
                                        make_tuple(I0, I0),
                                        in_thread_val_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    // initialize the indices for the per-thread to-reduce values
-                   in_thread_idx_buf(offset) =
-                       indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
+                   in_thread_idx_buf(Number<offset>{}) =
+                       indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();

                    // do element-wise pre-reduction operation
-                   in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
+                   in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                                     in_thread_val_buf(Number<offset>{}));
                });

                AccDataType tmpValue   = zeroVal;
                IndexDataType tmpIndex = 0;

-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    // reduce on the dim1 thread slice
-                   AccumulationWithIndex::Calculate(
-                       tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+                   AccumulationWithIndex::Calculate(tmpValue,
+                                                    in_thread_val_buf[Number<offset>{}],
+                                                    tmpIndex,
+                                                    in_thread_idx_buf[Number<offset>{}]);
                });

-               // store thread local value to LDS for parallel reduction
-               if constexpr(reorder_thread_cluster)
-               {
-                   block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) = tmpValue;
-                   block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) = tmpIndex;
-               }
-               else
-               {
-                   block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) = tmpValue;
-                   block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) = tmpIndex;
-               }
-
-               __syncthreads();
-
-               BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
-                                                block_reduce_idx_buf,
-                                                tmpValue,
-                                                tmpIndex,
-                                                thread_m_cluster_id,
-                                                thread_k_cluster_id);
+               BlockwiseReduceWithIndex::Reduce(
+                   reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

-               AccumulationWithIndex::Calculate(
-                   accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+               AccumulationWithIndex::Calculate(
+                   accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
...
...
@@ -537,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise
            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

-       constexpr auto reduced_data_desc =
-           make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+       constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
...
...
@@ -556,7 +543,7 @@ struct GridwiseReduction_mk_to_m_blockwise
        {
            if(!float_equal_zero{}(beta))
            {
-               StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+               StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
...
...
@@ -581,7 +568,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
+                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            };
        };
...
...
@@ -591,36 +578,36 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
-                                                  PassThroughOp<AccDataType>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
-                                                  InMemoryDataOperationEnum_t::Set,
+                                                  InMemoryDataOperationEnum::Set,
                                                   1,
                                                   false>(
                out_grid_desc_m,
                make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
-                                                  PassThroughOp<index_t>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
-                                                  InMemoryDataOperationEnum_t::Set,
+                                                  InMemoryDataOperationEnum::Set,
                                                   1,
                                                   false>(
                out_grid_desc_m,
                make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
-               PassThroughOp<index_t>{});
+               PassThroughOp{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
...
...
@@ -642,19 +629,20 @@ struct GridwiseReduction_mk_to_m_blockwise
                                         const OutElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_ws_values_global,
-                                        OutDataType beta,
+                                        AccDataType beta,
                                         OutDataType* const __restrict__ p_out_global,
                                         const IndexDataType* const __restrict__ p_ws_indices_global,
                                         IndexDataType* const __restrict__ p_indices_global)
    {
        static_assert(InSrcVectorDim == 1,
                      "InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");

-       using BlockwiseReduceWithIndex = PartitionedBlockwiseReductionWithIndexOn1dBuffer<
-           decltype(buffer_1d_desc), AccDataType, IndexDataType, BlockSize,
-           MThreadClusterSize, KThreadClusterSize, reorder_thread_cluster,
-           ReduceOperation, PropagateNan>;
+       using BlockwiseReduceWithIndex = PartitionedBlockwiseReductionWithIndex<
+           AccDataType, IndexDataType, BlockSize,
+           Sequence<MThreadClusterSize, KThreadClusterSize>, ThreadClusterArrangeOrder,
+           ReduceOperation, PropagateNan>;
...
...
@@ -666,90 +654,86 @@ struct GridwiseReduction_mk_to_m_blockwise
        (void)in_elementwise_op;

        // LDS
-       __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-       __shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
+       __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+       __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

-       const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ws_values_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
-       auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
-       auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

-       auto block_reduce_val_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
-       auto block_reduce_idx_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+       auto reduce_work_val_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+       auto reduce_work_idx_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_idx_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();
-       const index_t thread_m_cluster_id =
-           reorder_thread_cluster ? thread_local_id % MThreadClusterSize
-                                  : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
-       const index_t thread_k_cluster_id =
-           reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
-                                  : thread_local_id % KThreadClusterSize;
+       const auto thread_cluster_idx =
+           thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+       const auto thread_m_cluster_id = thread_cluster_idx[I0];
+       const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1, false>(
            in_grid_desc_m_k,
            make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
            IndexDataType, IndexDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1, false>(
            in_grid_desc_m_k,
            make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        // index_t indexOffset = 0;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
...
...
@@ -774,56 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise
                                        make_tuple(I0, I0),
                                        in_thread_idx_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType tmpValue   = zeroVal;
                IndexDataType tmpIndex = 0;

-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    // reduce on the dim1 thread slice
-                   AccumulationWithIndex::Calculate(
-                       tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+                   AccumulationWithIndex::Calculate(tmpValue,
+                                                    in_thread_val_buf[Number<offset>{}],
+                                                    tmpIndex,
+                                                    in_thread_idx_buf[Number<offset>{}]);
                });

-               // store thread local value to LDS for parallel reduction
-               if constexpr(reorder_thread_cluster)
-               {
-                   block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) = tmpValue;
-                   block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) = tmpIndex;
-               }
-               else
-               {
-                   block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) = tmpValue;
-                   block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) = tmpIndex;
-               }
-
-               __syncthreads();
-
-               BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
-                                                block_reduce_idx_buf,
-                                                tmpValue,
-                                                tmpIndex,
-                                                thread_m_cluster_id,
-                                                thread_k_cluster_id);
+               BlockwiseReduceWithIndex::Reduce(
+                   reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

-               AccumulationWithIndex::Calculate(
-                   accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+               AccumulationWithIndex::Calculate(
+                   accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
            });

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
            threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            // indexOffset += K_BlockTileSize;
            reducedTiles++;
        } while(reducedTiles < toReduceTiles);

-       constexpr auto reduced_data_desc =
-           make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+       constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
...
...
@@ -841,7 +802,7 @@ struct GridwiseReduction_mk_to_m_blockwise
        {
            if(!float_equal_zero{}(beta))
            {
-               StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+               StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
...
...
@@ -866,7 +827,7 @@ struct GridwiseReduction_mk_to_m_blockwise
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I] * beta);
+                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            };
        };
...
...
@@ -876,36 +837,36 @@ struct GridwiseReduction_mk_to_m_blockwise
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
-                                                  PassThroughOp<AccDataType>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
-                                                  InMemoryDataOperationEnum_t::Set,
+                                                  InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                out_grid_desc_m,
                make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
-                                                  PassThroughOp<IndexDataType>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
-                                                  InMemoryDataOperationEnum_t::Set,
+                                                  InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                out_grid_desc_m,
                make_multi_index(block_global_1d_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
-               PassThroughOp<index_t>{});
+               PassThroughOp{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp
...
...
@@ -30,8 +30,10 @@
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
+#include "element_wise_operation.hpp"

 namespace ck {
...
...
@@ -84,24 +86,46 @@ template <typename InDataType,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

-   static constexpr auto buffer_1d_desc =
-       make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
-
    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

-   using blockwise_reduce = PartitionedBlockwiseReductionOn1dBuffer<
-       decltype(buffer_1d_desc), AccDataType, BlockSize,
-       MThreadClusterSize, KThreadClusterSize, reorder_thread_cluster,
-       ReduceOperation, PropagateNan>;
+   using ThreadBufferDimAccessOrder =
+       typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+   using ThreadClusterArrangeOrder =
+       typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+   static constexpr auto thread_cluster_desc =
+       make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+
+   using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+       make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+   using ThreadReduceDstDesc_M =
+       decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
+
+   using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
+                                                         BlockSize,
+                                                         ThreadClusterLengths_M_K,
+                                                         ThreadClusterArrangeOrder,
+                                                         ReduceOperation,
+                                                         PropagateNan>;

-   template <typename T>
-   using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+   using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                ThreadReduceSrcDesc_M_K,
+                                                ThreadReduceDstDesc_M,
+                                                ReduceOperation,
+                                                PropagateNan>;
+
+   using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...
...
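The template-free PassThroughOp declared above comes from element_wise_operation.hpp. A simplified sketch of the shape of such an identity elementwise functor (illustrative only; the real CK PassThrough additionally carries __host__ __device__ qualifiers and type-specific overloads):

struct PassThroughSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = x; // identity: forward the input unchanged
    }
};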
@@ -121,23 +145,20 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
-       __shared__ AccDataType p_block_reduce_buffer[BlockSize];
+       __shared__ AccDataType p_reduce_work_buffer[BlockSize];

-       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

-       auto block_reduce_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+       auto reduce_work_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
...
...
@@ -145,12 +166,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;
-       const index_t thread_m_cluster_id =
-           reorder_thread_cluster ? thread_local_id % MThreadClusterSize
-                                  : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
-       const index_t thread_k_cluster_id =
-           reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
-                                  : thread_local_id % KThreadClusterSize;
+       const auto thread_cluster_idx =
+           thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+       const auto thread_m_cluster_id = thread_cluster_idx[I0];
+       const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
...
...
@@ -158,13 +179,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1,
...
...
@@ -185,49 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                                    make_tuple(I0, I0),
                                    in_thread_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-               });
-
-               // reduce on each thread-local slice
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                   in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
                });
            });

+           ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

-       constexpr auto reduced_data_desc =
-           make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+       constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        // Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
        // reduced output to the global location corresponding to each invariant dimension to get a
        // consistent reduced result for that invariant dimension. due to the using of vector_load,
        // each block/thread is involved into multiple invarirant dimensions.
-       static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-           if constexpr(reorder_thread_cluster)
-           {
-               block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
-                   accu_value_buf[I];
-           }
-           else
-               block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
-                   accu_value_buf[I];
-
-           accu_value_buf(I) = zeroVal;
-
-           __syncthreads();
-
-           blockwise_reduce::Reduce(
-               block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
-       });
+       static_for<0, MThreadSliceSize, 1>{}(
+           [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
...
...
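A hedged sketch (illustrative HIP/CUDA-style kernel, not the CK gridwise kernel) of the strategy described in the comment above: every block reduces its own K tile for one output row and atomically adds the partial result into out[m], so the per-block partials never need a second combining kernel. Assumes blockDim.x == 256 (a power of two) and that out is zero-initialized.

__global__ void reduce_k_atomic_add_sketch(const float* in, float* out, int M, int K, int k_per_block)
{
    __shared__ float lds[256];

    const int m       = blockIdx.x;
    const int k_begin = blockIdx.y * k_per_block;
    const int k_end   = (k_begin + k_per_block < K) ? k_begin + k_per_block : K;

    // each thread accumulates a strided portion of this block's K tile
    float partial = 0.0f;
    for(int k = k_begin + threadIdx.x; k < k_end; k += blockDim.x)
        partial += in[m * K + k];

    lds[threadIdx.x] = partial;
    __syncthreads();

    // tree reduction of the block's partial sums
    for(unsigned stride = blockDim.x / 2; stride > 0; stride /= 2)
    {
        if(threadIdx.x < stride)
            lds[threadIdx.x] += lds[threadIdx.x + stride];
        __syncthreads();
    }

    if(threadIdx.x == 0)
        atomicAdd(&out[m], lds[0]);
}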
@@ -245,18 +246,18 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
-                                                  PassThroughOp<AccDataType>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
-                                                  InMemoryDataOperationEnum_t::AtomicAdd,
+                                                  InMemoryDataOperationEnum::AtomicAdd,
                                                   1,
                                                   true>(
                out_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_partial_reduce.hpp
...
...
@@ -23,15 +23,17 @@
 * SOFTWARE.
 *
 *******************************************************************************/
-#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
-#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_TWO_CALL_HPP
+#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP
+#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP

 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
 #include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "cluster_descriptor.hpp"
+#include "element_wise_operation.hpp"

 namespace ck {
...
...
@@ -101,15 +103,34 @@ template <typename InDataType,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
{
    static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                      (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

-   static constexpr auto buffer1dDesc =
-       make_naive_tensor_descriptor_packed(make_tuple(Number<BlockSize>{}));
-
    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

+   using ThreadBufferDimAccessOrder =
+       typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+   using ThreadClusterArrangeOrder =
+       typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+   static constexpr auto thread_cluster_desc =
+       make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+
+   using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+       make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+   using ThreadReduceDstDesc_M =
+       decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

-   template <typename T>
-   using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+   using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...
...
@@ -124,17 +145,18 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                               AccDataType* const __restrict__ p_ws_values_global,
                               IndexDataType* const __restrict__ p_ws_indices_global)
    {
-       using BlockwiseReduce = PartitionedBlockwiseReductionOn1dBuffer<
-           decltype(buffer1dDesc), AccDataType, BlockSize,
-           MThreadClusterSize, KThreadClusterSize, reorder_thread_cluster,
-           ReduceOperation, PropagateNan>;
-
-       using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+       using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
+                                                             BlockSize,
+                                                             ThreadClusterLengths_M_K,
+                                                             ThreadClusterArrangeOrder,
+                                                             ReduceOperation,
+                                                             PropagateNan>;
+
+       using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                    ThreadReduceSrcDesc_M_K,
+                                                    ThreadReduceDstDesc_M,
+                                                    ReduceOperation,
+                                                    PropagateNan>;

        (void)p_ws_indices_global;
        (void)acc_elementwise_op;
...
...
@@ -142,25 +164,22 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
-       __shared__ AccDataType p_block_reduce_buffer[BlockSize];
+       __shared__ AccDataType p_reduce_work_buffer[BlockSize];

-       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_src_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());

-       auto block_reduce_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+       auto reduce_work_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
...
...
@@ -168,12 +187,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;
-       const index_t thread_m_cluster_id =
-           reorder_thread_cluster ? thread_local_id % MThreadClusterSize
-                                  : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
-       const index_t thread_k_cluster_id =
-           reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
-                                  : thread_local_id % KThreadClusterSize;
+       const auto thread_cluster_idx =
+           thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+       const auto thread_m_cluster_id = thread_cluster_idx[I0];
+       const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
...
...
@@ -181,13 +200,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1,
...
...
@@ -208,47 +226,29 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                    make_tuple(I0, I0),
                                    in_thread_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-               });
-
-               // reduce on each thread-local slice
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                   in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
                });
            });

+           ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

-       constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
-
        // Each block executes multiple parallel reductions on the LDS, and due to the using of
        // vector_load, each block/thread is involved into multiple invarirant dimensions.
-       static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-           if constexpr(reorder_thread_cluster)
-           {
-               block_reduce_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) =
-                   accu_value_buf[I];
-           }
-           else
-               block_reduce_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) =
-                   accu_value_buf[I];
-
-           accu_value_buf(I) = zeroVal;
-
-           __syncthreads();
-
-           BlockwiseReduce::Reduce(
-               block_reduce_buf, accu_value_buf(I), thread_m_cluster_id, thread_k_cluster_id);
-       });
+       static_for<0, MThreadSliceSize, 1>{}(
+           [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
+
+       constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
+           make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        if(thread_k_cluster_id == 0)
        {
...
...
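A hedged sketch (plain host-side C++, illustrative names, not the CK kernels) of the two-step "partial reduce" scheme this file implements: step 1 stores one partial result per (m, block_local_id) into a workspace, and a later second call reduces the workspace over the block dimension to produce the final out[m].

#include <algorithm>
#include <vector>

void multiblock_partial_reduce_sketch(const std::vector<float>& in, // M x K, row-major
                                      std::vector<float>& out,      // M
                                      int M, int K, int block_group_size)
{
    const int k_per_group = (K + block_group_size - 1) / block_group_size;
    std::vector<float> workspace(static_cast<size_t>(M) * block_group_size, 0.0f); // M x block_group_size

    // step 1: each "block" reduces only its own K slice into the workspace
    for(int m = 0; m < M; ++m)
        for(int g = 0; g < block_group_size; ++g)
            for(int k = g * k_per_group; k < std::min(K, (g + 1) * k_per_group); ++k)
                workspace[m * block_group_size + g] += in[m * K + k];

    // step 2: a second pass reduces the per-block partials
    for(int m = 0; m < M; ++m)
    {
        float acc = 0.0f;
        for(int g = 0; g < block_group_size; ++g)
            acc += workspace[m * block_group_size + g];
        out[m] = acc;
    }
}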
@@ -257,19 +257,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                                  AccDataType,
                                                  decltype(reduced_data_desc),
                                                  WorkspaceDesc_M_K,
-                                                 PassThroughOp<AccDataType>,
+                                                 PassThroughOp,
                                                  Sequence<MThreadSliceSize, 1>,
                                                  Sequence<0, 1>,
                                                  1,
                                                  1,
-                                                 InMemoryDataOperationEnum_t::Set,
+                                                 InMemoryDataOperationEnum::Set,
                                                  1,
                                                  true>(
                workspace_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            threadwise_workspace_store.Run(reduced_data_desc,
                                           make_tuple(I0, I0),
...
...
@@ -290,13 +290,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                        IndexDataType* const __restrict__ p_ws_indices_global)
    {
-       using BlockwiseReduceWithIndex = PartitionedBlockwiseReductionWithIndexOn1dBuffer<
-           decltype(buffer1dDesc), AccDataType, IndexDataType, BlockSize,
-           MThreadClusterSize, KThreadClusterSize, reorder_thread_cluster,
-           ReduceOperation, PropagateNan>;
+       using BlockwiseReduceWithIndex = PartitionedBlockwiseReductionWithIndex<
+           AccDataType, IndexDataType, BlockSize,
+           ThreadClusterLengths_M_K, ThreadClusterArrangeOrder,
+           ReduceOperation, PropagateNan>;
...
...
@@ -310,48 +308,44 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
-       __shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-       __shared__ index_t p_block_reduce_idx_buffer[BlockSize];
+       __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+       __shared__ index_t p_reduce_work_idx_buffer[BlockSize];

-       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_src_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
-       auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());

-       auto block_reduce_val_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
-       auto block_reduce_idx_buf =
-           make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+       auto reduce_work_val_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+       auto reduce_work_idx_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_idx_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        const index_t thread_local_id  = get_thread_local_1d_id();
        const index_t block_global_id  = get_block_1d_id();
        const index_t blkgroup_id      = block_global_id / block_group_size;
        const index_t block_local_id   = block_global_id % block_group_size;
-       const index_t thread_m_cluster_id =
-           reorder_thread_cluster ? thread_local_id % MThreadClusterSize
-                                  : ((thread_local_id / KThreadClusterSize) % MThreadClusterSize);
-       const index_t thread_k_cluster_id =
-           reorder_thread_cluster ? ((thread_local_id / MThreadClusterSize) % KThreadClusterSize)
-                                  : thread_local_id % KThreadClusterSize;
+       const auto thread_cluster_idx =
+           thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+       const auto thread_m_cluster_id = thread_cluster_idx[I0];
+       const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
...
...
@@ -359,13 +353,12 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1,
...
...
@@ -394,56 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                        make_tuple(I0, I0),
                                        in_thread_val_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    // initialize the indices for the per-thread to-reduce values
-                   in_thread_idx_buf(offset) =
-                       indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
+                   in_thread_idx_buf(Number<offset>{}) =
+                       indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();

                    // do element-wise pre-reduction operation
-                   in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
+                   in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                                     in_thread_val_buf(Number<offset>{}));
                });

                AccDataType tmpValue   = zeroVal;
                IndexDataType tmpIndex = 0;

-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                    // reduce on the dim1 thread slice
-                   AccumulationWithIndex::Calculate(
-                       tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+                   AccumulationWithIndex::Calculate(tmpValue,
+                                                    in_thread_val_buf[Number<offset>{}],
+                                                    tmpIndex,
+                                                    in_thread_idx_buf[Number<offset>{}]);
                });

-               // store thread local value to LDS for parallel reduction
-               if constexpr(reorder_thread_cluster)
-               {
-                   block_reduce_val_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) = tmpValue;
-                   block_reduce_idx_buf(thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id) = tmpIndex;
-               }
-               else
-               {
-                   block_reduce_val_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) = tmpValue;
-                   block_reduce_idx_buf(thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id) = tmpIndex;
-               }
-
-               __syncthreads();
-
-               BlockwiseReduceWithIndex::Reduce(block_reduce_val_buf,
-                                                block_reduce_idx_buf,
-                                                tmpValue,
-                                                tmpIndex,
-                                                thread_m_cluster_id,
-                                                thread_k_cluster_id);
+               BlockwiseReduceWithIndex::Reduce(
+                   reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

-               AccumulationWithIndex::Calculate(
-                   accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+               AccumulationWithIndex::Calculate(
+                   accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
            });

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
...
...
@@ -463,38 +436,38 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                                  AccDataType,
                                                  decltype(reduced_data_desc),
                                                  WorkspaceDesc_M_K,
-                                                 PassThroughOp<AccDataType>,
+                                                 PassThroughOp,
                                                  Sequence<MThreadSliceSize, 1>,
                                                  Sequence<0, 1>,
                                                  1,
                                                  1,
-                                                 InMemoryDataOperationEnum_t::Set,
+                                                 InMemoryDataOperationEnum::Set,
                                                  1,
                                                  true>(
                workspace_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            auto threadwise_workspace_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   WorkspaceDesc_M_K,
-                                                  PassThroughOp<IndexDataType>,
+                                                  PassThroughOp,
                                                   Sequence<MThreadSliceSize, 1>,
                                                   Sequence<0, 1>,
                                                   1,
                                                   1,
-                                                  InMemoryDataOperationEnum_t::Set,
+                                                  InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                workspace_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
-               PassThroughOp<IndexDataType>{});
+               PassThroughOp{});

            threadwise_workspace_val_store.Run(reduced_data_desc,
                                               make_tuple(I0, I0),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp
...
...
@@ -30,7 +30,9 @@
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
 #include "reduction_functions_accumulate.hpp"
+#include "reduction_functions_threadwise.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
+#include "element_wise_operation.hpp"

 namespace ck {
...
...
@@ -50,7 +52,7 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_global,
-                                        OutDataType beta,
+                                        AccDataType beta,
                                         OutDataType* const __restrict__ p_out_global,
                                         IndexDataType* const __restrict__ p_indices_global)
{
...
...
@@ -101,8 +103,20 @@ template <typename InDataType,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
-   template <typename T>
-   using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<T, T>;
+   static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
+                  (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
+                     (MThreadSliceSize % OutDstVectorSize == 0),
+                 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+   using ThreadBufferDimAccessOrder =
+       typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+   using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+       make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+   using ThreadReduceDstDesc_M =
+       decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
+
+   using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
...
...
@@ -112,30 +126,29 @@ struct GridwiseReduction_mk_to_m_threadwise
                              const AccElementwiseOperation& acc_elementwise_op,
                              AccDataType alpha,
                              const InDataType* const __restrict__ p_in_global,
-                             OutDataType beta,
+                             AccDataType beta,
                              OutDataType* const __restrict__ p_out_global,
                              IndexDataType* const __restrict__ p_indices_global)
    {
-       using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+       using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+                                                    ThreadReduceSrcDesc_M_K,
+                                                    ThreadReduceDstDesc_M,
+                                                    ReduceOperation,
+                                                    PropagateNan>;

        (void)p_indices_global;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

-       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
...
...
@@ -147,17 +160,17 @@ struct GridwiseReduction_mk_to_m_threadwise
        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K, decltype(thread_buffer_desc), ThreadBufferLengths,
-           typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type,
+           ThreadBufferDimAccessOrder,
            InSrcVectorDim, InSrcVectorSize, 1, false>(
            in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
...
...
@@ -170,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise
                                    make_tuple(I0, I0),
                                    in_thread_buf);

-           static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+           static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-               });
-
-               // reduce on each thread-local slice
-               static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-                   constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-                   Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+               static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+                   constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                   in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
                });
            });

+           ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
+
            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
...
...
@@ -195,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise
            accu_value_buf(I) *= alpha;
        });

-       constexpr auto reduced_data_desc =
-           make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+       constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if constexpr(!BetaIsZero)
        {
...
...
@@ -215,7 +224,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                    true>(out_grid_desc_m,
                          make_multi_index(thread_global_1d_id * MThreadSliceSize));

-               StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+               StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValue_buf;

                threadwise_dst_load.Run(out_grid_desc_m,
...
...
@@ -225,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                        priorDstValue_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I] * beta);
+                   accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
                });
            };
        };
...
...
@@ -235,17 +244,17 @@ struct GridwiseReduction_mk_to_m_threadwise
                                              OutDataType,
                                              decltype(reduced_data_desc),
                                              OutGridDesc_M,
-                                             PassThroughOp<AccDataType>,
+                                             PassThroughOp,
                                              Sequence<MThreadSliceSize>,
                                              Sequence<0>,
                                              0,
                                              OutDstVectorSize,
-                                             InMemoryDataOperationEnum_t::Set,
+                                             InMemoryDataOperationEnum::Set,
                                              1,
                                              false>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
-               PassThroughOp<AccDataType>{});
+               PassThroughOp{});

            threadwise_dst_store.Run(
                reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
...
...
@@ -257,34 +266,39 @@ struct GridwiseReduction_mk_to_m_threadwise
                                       const AccElementwiseOperation& acc_elementwise_op,
                                       AccDataType alpha,
                                       const InDataType* const __restrict__ p_in_global,
-                                      OutDataType beta,
+                                      AccDataType beta,
                                       OutDataType* const __restrict__ p_out_global,
                                       IndexDataType* const __restrict__ p_indices_global)
    {
-       using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
-                                                                            ReduceOperation,
-                                                                            AccDataType,
-                                                                            IndexDataType>;
+       using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
+                                                                      IndexDataType,
+                                                                      ThreadReduceSrcDesc_M_K,
+                                                                      ThreadReduceDstDesc_M,
+                                                                      ReduceOperation,
+                                                                      PropagateNan>;

        (void)acc_elementwise_op;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

-       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-       auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc_m.GetElementSpaceSize());
-       auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_indices_global, out_grid_desc_m.GetElementSpaceSize());

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-           in_thread_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+           in_thread_val_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true>
+           in_thread_idx_buf;

-       StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-       StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+       StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
...
...
@@ -299,17 +313,17 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t
thread_global_1d_id
=
get_block_1d_id
()
*
BlockSize
+
get_thread_local_1d_id
();
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
InGridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
typename
conditional
<
InSrcVectorDim
==
0
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
,
0
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
KThreadSliceSize
);
...
...
@@ -321,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise
in_global_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
in_thread_
val_
buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// do element-wise pre-reduction operation
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
J
)
{
constexpr
auto
offset
=
I
*
Number
<
KThreadSliceSize
>
{}
+
J
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
;
in_elementwise_op
(
in_thread_buf
(
offset
),
in_thread_buf
(
offset
));
});
in_thread_idx_buf
(
Number
<
offset
>
{})
=
indexStart
+
iK
();
// reduce on each thread-local slice
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
J
)
{
constexpr
auto
offset
=
I
*
Number
<
KThreadSliceSize
>
{}
+
J
;
AccumulationWithIndex
::
Calculate
(
accu_value_buf
(
I
),
in_thread_buf
[
offset
],
accu_index_buf
(
I
),
indexStart
+
J
);
in_elementwise_op
(
in_thread_val_buf
(
Number
<
offset
>
{}),
in_thread_val_buf
(
Number
<
offset
>
{}));
});
});
ThreadwiseReduceWithIndex
::
Reduce
(
in_thread_val_buf
,
in_thread_idx_buf
,
accu_value_buf
,
accu_index_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_step
);
indexStart
+=
KThreadSliceSize
;
...
...
@@ -354,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf
(
I
)
*=
alpha
;
});
constexpr
auto
reduced_data_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
reduced_data_desc
=
ThreadReduceDstDesc_M
{};
if
constexpr
(
!
BetaIsZero
)
{
...
...
@@ -374,7 +384,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false
>
(
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
));
StaticBuffer
<
AddressSpaceEnum
_t
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
OutDataType
,
MThreadSliceSize
,
true
>
priorDstValue_buf
;
threadwise_dst_load
.
Run
(
out_grid_desc_m
,
...
...
@@ -384,7 +394,7 @@ struct GridwiseReduction_mk_to_m_threadwise
priorDstValue_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValue_buf
[
I
]
*
beta
)
;
accu_value_buf
(
I
)
+=
type_convert
<
AccDataType
>
(
priorDstValue_buf
[
I
]
)
*
beta
;
});
};
};
...
...
@@ -394,34 +404,34 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
<
AccDataType
>
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
),
PassThroughOp
<
AccDataType
>
{});
PassThroughOp
{});
auto
threadwise_dst_idx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
IndexDataType
,
IndexDataType
,
decltype
(
reduced_data_desc
),
OutGridDesc_M
,
PassThroughOp
<
IndexDataType
>
,
PassThroughOp
,
Sequence
<
MThreadSliceSize
>
,
Sequence
<
0
>
,
0
,
OutDstVectorSize
,
InMemoryDataOperationEnum
_t
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
(
out_grid_desc_m
,
make_multi_index
(
thread_global_1d_id
*
MThreadSliceSize
),
PassThroughOp
<
IndexDataType
>
{});
PassThroughOp
{});
threadwise_dst_val_store
.
Run
(
reduced_data_desc
,
make_tuple
(
I0
),
accu_value_buf
,
out_grid_desc_m
,
out_global_val_buf
);
...
...
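Note on the refactor above: the hand-rolled per-thread accumulation (AccumulateWithIndexAndNanCheck plus explicit offset arithmetic) is replaced by the ThreadwiseReduction / ThreadwiseReductionWithIndex helpers, whose Reduce() consumes a per-thread M x K value (and index) buffer and folds it into M accumulators. The snippet below is only an illustrative host-side sketch of that contract, with max-with-index assumed as the ReduceOperation; it is not the library implementation.

    #include <array>
    #include <cstdint>

    // Sketch: reduce an M x K per-thread slice into M (value, index) accumulators.
    template <int M, int K>
    void threadwise_reduce_with_index_sketch(const std::array<float, M * K>& in_val,
                                             const std::array<int32_t, M * K>& in_idx,
                                             std::array<float, M>& acc_val,
                                             std::array<int32_t, M>& acc_idx)
    {
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
            {
                // same element the real code addresses via
                // thread_buffer_desc.CalculateOffset(make_tuple(iM, iK))
                const float v = in_val[m * K + k];
                if(v > acc_val[m]) // "max" assumed here; the real ReduceOperation is a template parameter
                {
                    acc_val[m] = v;
                    acc_idx[m] = in_idx[m * K + k];
                }
            }
    }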
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_xdlops_v2r3.hpp
deleted
100644 → 0
View file @
0aa899aa
#ifndef CK_GRIDWISE_BATCHED_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_BATCHED_GEMM_XDLOPS_V2R3_HPP

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseBatchedGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_G_K0_M_K1,
          typename BGridDesc_G_K0_N_K1,
          typename CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
                                    const FloatAB* __restrict__ p_b_grid,
                                    FloatC* __restrict__ p_c_grid,
                                    const AGridDesc_G_K0_M_K1 a_grid_desc_g_k0_m_k1,
                                    const BGridDesc_G_K0_N_K1 b_grid_desc_g_k0_n_k1,
                                    const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                                    const AElementwiseOperation a_element_op,
                                    const BElementwiseOperation b_element_op,
                                    const CElementwiseOperation c_element_op,
                                    const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseBatchedGemm::GetSharedMemoryNumberOfByte()];

    GridwiseBatchedGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                         p_b_grid,
                                                         p_c_grid,
                                                         p_shared,
                                                         a_grid_desc_g_k0_m_k1,
                                                         b_grid_desc_g_k0_n_k1,
                                                         c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                                                         a_element_op,
                                                         b_element_op,
                                                         c_element_op,
                                                         block_2_ctile_map);
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AGridDesc_G_K0_M_K1,
          typename BGridDesc_G_K0_N_K1,
          typename CGridDesc_G_M_N,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t K1Value,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_G_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_G_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseBatchedGemm_gk0mk1_gk0nk1_gmn_xdlops_v2r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};
    static constexpr auto I8 = Number<8>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr auto GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_g_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                               Number<MPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_g_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_g_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                               Number<NPerBlock + 1>{} * K1,
                               K1,
                               I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(I1, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_g_k0_n_k1;
    }

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto a_block_desc_g_k0_m_k1 = GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();

        constexpr auto K0 = a_block_desc_g_k0_m_k1.GetLength(I1);
        constexpr auto M  = a_block_desc_g_k0_m_k1.GetLength(I2);

        constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor(
            a_block_desc_g_k0_m_k1,
            make_tuple(make_freeze_transform(I0),
                       make_pass_through_transform(K0),
                       make_pass_through_transform(M),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto b_block_desc_g_k0_n_k1 = GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();

        constexpr auto K0 = b_block_desc_g_k0_n_k1.GetLength(I1);
        constexpr auto N  = b_block_desc_g_k0_n_k1.GetLength(I2);

        constexpr auto b_block_desc_k0_n_k1 = transform_tensor_descriptor(
            b_block_desc_g_k0_n_k1,
            make_tuple(make_freeze_transform(I0),
                       make_pass_through_transform(K0),
                       make_pass_through_transform(N),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return b_block_desc_k0_n_k1;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_g_k0_m_k1 = GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_g_k0_n_k1 = GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_g_k0_m_k1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned =
            math::integer_least_multiple(b_block_desc_g_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_G_K0_M_K1& a_grid_desc_g_k0_m_k1,
                                                            const BGridDesc_G_K0_N_K1& b_grid_desc_g_k0_n_k1,
                                                            const CGridDesc_G_M_N& c_grid_desc_g_m_n,
                                                            index_t M01,
                                                            index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                      "Invalid tuning param!");

        // const auto G = a_grid_desc_g_k0_m_k1.GetLength(I0);
        const auto K0 = a_grid_desc_g_k0_m_k1.GetLength(I1);
        const auto M  = a_grid_desc_g_k0_m_k1.GetLength(I2);
        const auto N  = b_grid_desc_g_k0_n_k1.GetLength(I2);

        if(!(M == c_grid_desc_g_m_n.GetLength(I1) && N == c_grid_desc_g_m_n.GetLength(I2) &&
             K0 == b_grid_desc_g_k0_n_k1.GetLength(I1) && K1 == a_grid_desc_g_k0_m_k1.GetLength(I3) &&
             K1 == b_grid_desc_g_k0_n_k1.GetLength(I3)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const index_t grid_size = G * (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;

        return has_main_k0_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0_n_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                K1>;

        return BlockwiseGemm::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_g_m_n);
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_G_M_N& c_grid_desc_g_m_n, index_t M01, index_t N01)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M00, M01)),
                       make_unmerge_transform(make_tuple(N00, N01))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(G, M00, N00, M01, N01))),
            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
            make_tuple(Sequence<0>{}));

        const auto cblockid_to_g_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_g_m0_n0_block_cluster_adaptor;
    }

    using CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_G_M_N{}));
    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_G_M_N{}, 1, 1));

    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_G_K0_M_K1& a_grid_desc_g_k0_m_k1,
                               const BGridDesc_G_K0_N_K1& b_grid_desc_g_k0_n_k1,
                               const CGridDesc_G_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_g_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_grid_desc_g_k0_n_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        const auto K0 = a_grid_desc_g_k0_m_k1.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t g_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_g_k0_m_k1 = GetABlockDescriptor_BatchCount_K0PerBlock_MPerBlock_K1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_g_k0_n_k1 = GetBBlockDescriptor_BatchCount_K0PerBlock_NPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              AElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, MPerBlock, K1>,
                                              ABlockTransferThreadClusterLengths_G_K0_M_K1,
                                              ABlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(a_grid_desc_g_k0_m_k1),
                                              decltype(a_block_desc_g_k0_m_k1),
                                              ABlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              ABlockTransferSrcVectorDim,
                                              3,
                                              ABlockTransferSrcScalarPerVector,
                                              ABlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              AThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                a_grid_desc_g_k0_m_k1,
                make_multi_index(g_idx_on_grid, 0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_g_k0_m_k1,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
                                              BElementwiseOperation,
                                              ck::tensor_operation::element_wise::PassThrough,
                                              InMemoryDataOperationEnum_t::Set,
                                              Sequence<1, K0PerBlock, NPerBlock, K1>,
                                              BBlockTransferThreadClusterLengths_G_K0_N_K1,
                                              BBlockTransferThreadClusterArrangeOrder,
                                              FloatAB,
                                              FloatAB,
                                              decltype(b_grid_desc_g_k0_n_k1),
                                              decltype(b_block_desc_g_k0_n_k1),
                                              BBlockTransferSrcAccessOrder,
                                              Sequence<0, 2, 1, 3>,
                                              BBlockTransferSrcVectorDim,
                                              3,
                                              BBlockTransferSrcScalarPerVector,
                                              BBlockTransferDstScalarPerVector_K1,
                                              1,
                                              1,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
                b_grid_desc_g_k0_n_k1,
                make_multi_index(g_idx_on_grid, 0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_g_k0_n_k1,
                make_multi_index(0, 0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        auto blockwise_gemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_block_desc_k0_m_k1),
                                                                decltype(b_block_desc_k0_n_k1),
                                                                MPerXDL,
                                                                NPerXDL,
                                                                MXdlPerWave,
                                                                NXdlPerWave,
                                                                K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned =
            math::integer_least_multiple(a_block_desc_g_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_g_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_g_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_g_k0_m_k1, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc_g_k0_n_k1, b_grid_buf);

            a_blockwise_copy.RunWrite(a_block_desc_g_k0_m_k1, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc_g_k0_n_k1, b_block_buf);
        }

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainKBlockLoop)
        {
            index_t k0_block_data_begin = 0;

            do
            {
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_g_k0_m_k1, a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_g_k0_n_k1, b_block_slice_copy_step);

                a_blockwise_copy.RunRead(a_grid_desc_g_k0_m_k1, a_grid_buf);

                block_sync_lds();

                b_blockwise_copy.RunRead(b_grid_desc_g_k0_n_k1, b_grid_buf);

                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc_g_k0_m_k1, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc_g_k0_n_k1, b_block_buf);

                k0_block_data_begin += K0PerBlock;
            } while(k0_block_data_begin < (K0 - K0PerBlock));
        }

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr auto c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2();
            constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2();

            // constexpr auto G = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
            constexpr auto M0 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
            constexpr auto N0 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
            constexpr auto M1 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
            constexpr auto N1 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
            constexpr auto M2 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
            constexpr auto M3 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
            constexpr auto M4 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
            constexpr auto N2 = c_block_desc_g_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I8);

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy =
                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   CElementwiseOperation,
                                                   Sequence<I1, M0, N0, I1, I1, M2, I1, M4, I1>,
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
                                                   CThreadTransferDstScalarPerVector,
                                                   CGlobalMemoryDataOperation,
                                                   1,
                                                   true>{
                    c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(g_idx_on_grid,
                                     m_thread_data_on_grid_idx[I0],
                                     n_thread_data_on_grid_idx[I0],
                                     m_thread_data_on_grid_idx[I1],
                                     n_thread_data_on_grid_idx[I1],
                                     m_thread_data_on_grid_idx[I2],
                                     m_thread_data_on_grid_idx[I3],
                                     m_thread_data_on_grid_idx[I4],
                                     n_thread_data_on_grid_idx[I2]),
                    c_element_op};

            c_thread_copy.Run(c_thread_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_g_m0_n0_m1_n1_m2_m3_m4_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck
#endif
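For orientation: the deleted kernel launched G * (M / MPerBlock) * (N / NPerBlock) workgroups (see CalculateGridSize above) and used the chained tensor adaptors in MakeDefaultBlock2CTileMap to turn a flat block id into a (batch, m-tile, n-tile) coordinate. The sketch below reproduces only that index arithmetic for the simple M01 = N01 = 1 case; it is an illustration, not the adaptor-based implementation.

    #include <cassert>

    struct BlockTileCoord { int g, m0, n0; };

    // Sketch of block-id -> (g, m0, n0) for M01 = N01 = 1 (no tile interleaving).
    BlockTileCoord decompose_block_id_sketch(int block_id, int m_tiles, int n_tiles)
    {
        const int tiles_per_batch = m_tiles * n_tiles;
        assert(block_id >= 0 && m_tiles > 0 && n_tiles > 0);
        return BlockTileCoord{block_id / tiles_per_batch,
                              (block_id % tiles_per_batch) / n_tiles,
                              block_id % n_tiles};
    }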
include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp
View file @
dd6a8de4
...
...
@@ -55,7 +55,7 @@ template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
-         InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+         InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_GK0_GM0_GM1_GK1,
          typename BGridDesc_GK0_GN0_GN1_GK1,
          typename CGridDesc_GM0_GM1_GN0_GN1,
...
@@ -329,11 +329,11 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
-       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());

        const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
...
@@ -383,7 +383,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
-           InMemoryDataOperationEnum_t::Set,
+           InMemoryDataOperationEnum::Set,
            Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
...
@@ -407,7 +407,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
-           InMemoryDataOperationEnum_t::Set,
+           InMemoryDataOperationEnum::Set,
            Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
            BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
...
@@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
-       auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
+       auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

        ThreadwiseTensorSliceSet_v1<FloatAcc,
...
@@ -481,15 +481,15 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
        constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);

-       auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
-       auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());

-       auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
-       auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
...
...
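The only functional change in this file is the rename of the scoped enums AddressSpaceEnum_t and InMemoryDataOperationEnum_t to AddressSpaceEnum and InMemoryDataOperationEnum at every use site. A minimal sketch of what the rename looks like on the definition side is given below; the transition aliases and the trimmed enumerator lists are assumptions for illustration only, not part of this commit.

    // Sketch only: scoped enums renamed, old spellings kept as hypothetical aliases during migration.
    enum class AddressSpaceEnum
    {
        Global, // off-chip global memory
        Lds,    // workgroup-shared LDS
        Vgpr    // per-thread registers
    };
    using AddressSpaceEnum_t = AddressSpaceEnum; // hypothetical compatibility alias

    enum class InMemoryDataOperationEnum
    {
        Set,
        Add
    };
    using InMemoryDataOperationEnum_t = InMemoryDataOperationEnum; // hypothetical compatibility alias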
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp
View file @
dd6a8de4
...
...
@@ -55,7 +55,7 @@ template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
-         InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+         InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AKMGridDesc,
          typename BKNGridDesc,
          typename CMNGridDesc,
...
@@ -268,11 +268,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
-       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

        const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
...
@@ -315,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                                               InMemoryDataOperationEnum_t::Set,
+                                                               InMemoryDataOperationEnum::Set,
                                                                Sequence<KPerBlock, 1, MPerBlockM1>,
                                                                ABlockTransferThreadSliceLengths_K_M0_M1,
                                                                ABlockTransferThreadClusterLengths_K_M0_M1,
...
@@ -341,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                                               InMemoryDataOperationEnum_t::Set,
+                                                               InMemoryDataOperationEnum::Set,
                                                                Sequence<KPerBlock, 1, NPerBlockN1>,
                                                                BBlockTransferThreadSliceLengths_K_N0_N1,
                                                                BBlockTransferThreadClusterLengths_K_N0_N1,
...
@@ -403,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
-       auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
+       auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

        ThreadwiseTensorSliceSet_v1<FloatAcc,
...
@@ -428,15 +428,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
        constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
            BGridMoveSliceWindowStepHacks{};

-       auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
-       auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

-       auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size, a_k_m0_m1_block_desc.GetElementSpaceSize());
-       auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size, b_k_n0_n1_block_desc.GetElementSpaceSize());
...
...
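The a/b_block_even_buf and a/b_block_odd_buf pairs touched above implement double-buffered (ping-pong) LDS staging: while the GEMM consumes one copy of the A/B tiles, the next K slice is written into the other. A stripped-down sketch of that control flow is shown below; the function and parameter names are placeholders, not the kernel's, and synchronization between load and compute is deliberately omitted.

    // Sketch: ping-pong between two tile buffers across the K loop (k_blocks >= 1).
    template <typename Tile, typename LoadFn, typename GemmFn>
    void k_loop_ping_pong_sketch(Tile& buf_even, Tile& buf_odd, int k_blocks, LoadFn load, GemmFn gemm)
    {
        load(buf_even, /*k_block=*/0); // prologue: stage the first tile

        for(int kb = 0; kb + 1 < k_blocks; ++kb)
        {
            Tile& read_buf  = (kb % 2 == 0) ? buf_even : buf_odd; // staged last iteration
            Tile& write_buf = (kb % 2 == 0) ? buf_odd : buf_even; // idle this iteration

            load(write_buf, kb + 1); // prefetch the next K slice
            gemm(read_buf);          // consume the current K slice
        }

        gemm((k_blocks % 2 == 1) ? buf_even : buf_odd); // epilogue: last staged tile
    }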
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp
View file @
dd6a8de4
...
...
@@ -55,7 +55,7 @@ template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
-         InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+         InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
...
@@ -275,11 +275,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                               integral_constant<bool, HasMainKBlockLoop>,
                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
-       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

        // divide block work by [M, N]
...
@@ -325,7 +325,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
-           InMemoryDataOperationEnum_t::Set,
+           InMemoryDataOperationEnum::Set,
            Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
...
@@ -349,7 +349,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
-           InMemoryDataOperationEnum_t::Set,
+           InMemoryDataOperationEnum::Set,
            Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
...
@@ -409,7 +409,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
-       auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
+       auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

        ThreadwiseTensorSliceSet_v1<FloatAcc,
...
@@ -423,15 +423,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);

-       auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
-       auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());

-       auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
-       auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
...
...
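The CGlobalMemoryDataOperation parameter renamed above selects how a computed value is committed to its destination element: Set overwrites, Add accumulates into what is already there (the atomic-add reduction devices elsewhere in this commit rely on the same idea). A scalar sketch of that dispatch, with the enumerator list limited to the two values that appear in these hunks:

    enum class InMemoryDataOperationEnum { Set, Add }; // enumerators limited to those used here

    template <InMemoryDataOperationEnum Op, typename T>
    void commit_sketch(T& dst, T src)
    {
        if constexpr(Op == InMemoryDataOperationEnum::Set)
            dst = src;  // plain store
        else
            dst += src; // read-modify-write accumulate
    }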
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp
View file @
dd6a8de4
...
...
@@ -15,7 +15,7 @@ template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
-         InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+         InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
...
@@ -84,11 +84,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
-       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());

        constexpr auto E = EPerBlock * 3 * 3;
...
@@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                                               InMemoryDataOperationEnum_t::Set,
+                                                               InMemoryDataOperationEnum::Set,
                                                                Sequence<E, KPerBlock>,
                                                                ABlockTransferThreadSliceLengths_E_K,
                                                                ABlockTransferThreadClusterLengths_E_K,
...
@@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
            b_e_n_ho_wo_global_desc,
            make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_shared_block, a_e_k_desc.GetElementSpaceSize());

        // register allocation for output
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAcc,
                     c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
                     true>
...
@@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
            BGlobalMoveSliceWindowStepHacks{};

        // double regsiter buffer for b
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAB,
                     b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
                     true>
...
...
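StaticBuffer<AddressSpaceEnum::Vgpr, T, N, true>, which these hunks retag with the renamed enum, is a compile-time sized array intended to live entirely in per-thread registers. A reduced host-side analogue, showing only the sizing and indexing behaviour and none of the address-space or compile-time-index machinery, might look like this (names are illustrative only):

    template <typename T, int N>
    struct StaticBufferSketch
    {
        T data[N] = {};

        constexpr T& operator()(int i) { return data[i]; }            // mutable element access
        constexpr const T& operator[](int i) const { return data[i]; } // read-only element access

        constexpr void Clear()
        {
            for(int i = 0; i < N; ++i)
                data[i] = T{};
        }
    };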
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp
View file @
dd6a8de4
...
...
@@ -20,7 +20,7 @@ template <typename GridwiseGemm,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
-         ActivTypeEnum_t ActivType>
+         ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -50,7 +50,7 @@ __global__ void
            c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
            cblockid_to_k_n_h_w_block_cluster_adaptor,
            integral_constant<bool, HasMainE0BlockLoop>{},
-           integral_constant<ActivTypeEnum_t, ActivType>{});
+           integral_constant<ActivTypeEnum, ActivType>{});
}

template <typename GridwiseGemm,
...
@@ -62,7 +62,7 @@ template <typename GridwiseGemm,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
-         ActivTypeEnum_t ActivType>
+         ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -94,7 +94,7 @@ __global__ void
            d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
            cblockid_to_k_n_h_w_block_cluster_adaptor,
            integral_constant<bool, HasMainE0BlockLoop>{},
-           integral_constant<ActivTypeEnum_t, ActivType>{});
+           integral_constant<ActivTypeEnum, ActivType>{});
}

template <typename GridwiseGemm,
...
@@ -106,7 +106,7 @@ template <typename GridwiseGemm,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
-         ActivTypeEnum_t ActivType>
+         ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -140,14 +140,14 @@ __global__ void
            d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
            cblockid_to_k_n_h_w_block_cluster_adaptor,
            integral_constant<bool, HasMainE0BlockLoop>{},
-           integral_constant<ActivTypeEnum_t, ActivType>{});
+           integral_constant<ActivTypeEnum, ActivType>{});
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
-         InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+         InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_E0_E1_K_E2,
          typename BGridDesc_E0_E1_N_Ho_Wo_E2,
          typename CGridDesc_K_N_Ho_Wo,
...
@@ -559,7 +559,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        constexpr auto bias_k0_k1_thread_desc =
            make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));

-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatC,
                     bias_k0_k1_thread_desc.GetElementSpaceSize(),
                     true>
...
@@ -602,10 +602,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        });
    }

-   template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum_t activ_type_>
+   template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
    __device__ static void Activation(CThreadBuff& c_thread_buf,
                                      const CThreadDesc_K1_N_H2_W2&,
-                                     integral_constant<ActivTypeEnum_t, activ_type_>)
+                                     integral_constant<ActivTypeEnum, activ_type_>)
    {
        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
...
@@ -737,7 +737,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                                                I1,
                                                Number<WoPerThread_2>{}));

-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatC,
                     d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
                     true>
...
@@ -783,7 +783,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
-               InMemoryDataOperationEnum_t::Set,
+               InMemoryDataOperationEnum::Set,
                1,
                true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                      make_multi_index(k_block_work_id,
...
@@ -843,7 +843,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                                                I1,
                                                Number<WoPerThreadx2>{}));

-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatC,
                     d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
                     true>
...
@@ -874,7 +874,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
-               InMemoryDataOperationEnum_t::Add,
+               InMemoryDataOperationEnum::Add,
                1,
                true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                      make_multi_index(k_block_work_id,
...
@@ -964,7 +964,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<BlockSize,
-                                                               InMemoryDataOperationEnum_t::Set,
+                                                               InMemoryDataOperationEnum::Set,
                                                                Sequence<I1, E1, I1, KPerBlock, E2>,
                                                                ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
                                                                ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
...
@@ -1023,11 +1023,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                             0,
                             0));

-       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());

        //// register allocation for output
-       // StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       // StaticBuffer<AddressSpaceEnum::Vgpr,
        //              FloatAcc,
        //              c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
        //              true>
...
@@ -1050,7 +1050,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};

        // double regsiter buffer for b
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAB,
                     b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
                     true>
...
@@ -1294,21 +1294,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        const auto bias_k0_k1_grid_desc =
            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-       auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
-       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
...
@@ -1344,7 +1344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop,
-             ActivTypeEnum_t ActivType>
+             ActivTypeEnum ActivType>
    __device__ static void ConvBiasActiv(const FloatAB* __restrict__ p_a_global,
                                         const FloatAB* __restrict__ p_b_global,
...
@@ -1356,26 +1356,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
        const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
        integral_constant<bool, HasMainE0BlockLoop>,
-       integral_constant<ActivTypeEnum_t, ActivType>)
+       integral_constant<ActivTypeEnum, ActivType>)
    {
-       static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
+       static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};

        const auto bias_k0_k1_grid_desc =
            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
...
@@ -1423,7 +1423,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop,
-             ActivTypeEnum_t ActivType>
+             ActivTypeEnum ActivType>
    __device__ static void ConvBiasActivMaxpool(const FloatAB* __restrict__ p_a_global,
                                                const FloatAB* __restrict__ p_b_global,
...
@@ -1437,28 +1437,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
        const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
        integral_constant<bool, HasMainE0BlockLoop>,
-       integral_constant<ActivTypeEnum_t, ActivType>)
+       integral_constant<ActivTypeEnum, ActivType>)
    {
-       static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
+       static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};

        const auto bias_k0_k1_grid_desc =
            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
-       auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
-       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
...
@@ -1514,7 +1514,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
              bool HasMainE0BlockLoop,
-             ActivTypeEnum_t ActivType>
+             ActivTypeEnum ActivType>
    __device__ static void ConvBiasActivResizeAdd(const FloatAB* __restrict__ p_a_global,
                                                  const FloatAB* __restrict__ p_b_global,
...
@@ -1527,26 +1527,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
        const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
        const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
        integral_constant<bool, HasMainE0BlockLoop>,
-       integral_constant<ActivTypeEnum_t, ActivType>)
+       integral_constant<ActivTypeEnum, ActivType>)
    {
-       static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{};
+       static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};

        const auto bias_k0_k1_grid_desc =
            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

-       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
-       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
-       auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
-       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+       auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());

        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();

        // register allocation for output
-       StaticBuffer<AddressSpaceEnum_t::Vgpr,
+       StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAcc,
                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
                     true>
...
...
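The ActivTypeEnum_t to ActivTypeEnum rename above runs through the integral_constant tags used to select the activation at compile time (see the Activation(...) calls in the hunks). A self-contained sketch of that tag-dispatch pattern follows; the enumerator names are made up for the example and are not taken from the library.

    #include <cmath>
    #include <type_traits>

    enum class ActivTypeEnum { None, LeakyRelu, Sigmoid }; // enumerator names assumed for illustration

    template <ActivTypeEnum A>
    float activation_sketch(float x, std::integral_constant<ActivTypeEnum, A>)
    {
        if constexpr(A == ActivTypeEnum::LeakyRelu)
            return x > 0.0f ? x : 0.01f * x;
        else if constexpr(A == ActivTypeEnum::Sigmoid)
            return 1.0f / (1.0f + std::exp(-x));
        else
            return x; // None: pass through
    }

    // usage: activation_sketch(v, std::integral_constant<ActivTypeEnum, ActivTypeEnum::Sigmoid>{});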
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
0 → 100644
View file @
dd6a8de4
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {

template <typename GridwiseGemm, typename FloatAB, typename FloatC, typename FloatD,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation, typename D1ReduceOperation,
          typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename DGridDescriptor_MBlock_MPerBlock,
          typename Block2CTileMap, bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_reduce_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            FloatD* __restrict__ p_d0_grid,
            FloatD* __restrict__ p_d1_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const D0ReduceOperation d0_reduce_op,
            const D1ReduceOperation d1_reduce_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
            const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid, p_b_grid, p_c_grid, p_d0_grid, p_d1_grid, p_shared,
        a_element_op, b_element_op, c_element_op, d0_reduce_op, d1_reduce_op,
        a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
        c_grid_desc_mblock_mperblock_nblock_nperblock,
        d_grid_desc_mblock_mperblock,
        block_2_ctile_map);
}
template <typename FloatAB, typename FloatGemmAcc, typename FloatCShuffle, typename FloatC,
          typename FloatReduceAcc, typename FloatD,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename D0ReduceOperation, typename D1ReduceOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          InMemoryDataOperationEnum DGlobalMemoryDataOperation,
          typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1,
          typename CGridDesc_M_N, typename DGridDesc_M,
          index_t NumGemmKPrefetchStage, index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t AK1Value, index_t BK1Value,
          index_t MPerXdl, index_t NPerXdl, index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0, Number<MPerBlock>{}, AK1),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    __host__ __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             sizeof(FloatAB),
                         c_block_size * sizeof(FloatCShuffle));
    }
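
    // Note (added for illustration; the numbers below are hypothetical, not taken from this
    // file): for fp16 inputs with MPerBlock = 256, NPerBlock = 128, KPerBlock = 32,
    // AK1 = BK1 = 8 and no LDS padding, the A tile needs (32/8) * 256 * 8 = 8192 elements
    // (16 KiB) and the B tile (32/8) * 128 * 8 = 4096 elements (8 KiB), so the A+B stage
    // uses 24 KiB. GetSharedMemoryNumberOfByte() returns the larger of that figure and the
    // C-shuffle staging buffer, since the same LDS region is reused for both phases.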
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
        //               "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
        const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
        const auto K =
            a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
            return false;

        // check NumGemmKPrefetchStage
        if constexpr(NumGemmKPrefetchStage == 1)
        {
            // 1-stage prefetch is always supported
        }
        else if constexpr(NumGemmKPrefetchStage == 2)
        {
            // 2-stage prefetch currently only supports an even number of K0 loop iterations
            // TODO: add support for odd number of K0 loop
            if(!((K / KPerBlock) % 2 == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop =
            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;

        return has_main_k0_block_loop;
    }
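
    // Worked example (illustrative values only, assuming the caller passes K0 = K / AK1):
    // for K = 4096, KPerBlock = 32 and NumGemmKPrefetchStage = 1,
    // (K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock) = 4096 / 32 = 128 > 1, so the kernel
    // is instantiated with HasMainK0BlockLoop = true; only a problem with a single K tile
    // per block (K == KPerBlock) falls back to the no-main-loop variant.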
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    __host__ __device__ static constexpr auto
    MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
    {
        const auto M = d_grid_desc_m.GetLength(I0);

        const auto MBlock = M / MPerBlock;

        const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
            d_grid_desc_m,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1>{}));

        return d_grid_desc_mblock_mperblock;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        // FIXME: remove
        constexpr auto M01 = I1;
        constexpr auto N01 = I1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

    using DGridDescriptor_MBlock_MPerBlock =
        remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;

    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
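
    // With the current M01 = N01 = 1 defaults, the adaptor chain above degenerates to a
    // plain row-major walk over C tiles: for NBlock = N / NPerBlock, block_id maps to
    // (m0, n0) = (block_id / NBlock, block_id % NBlock). The two-level M00/M01 split only
    // becomes meaningful once a swizzled block ordering is plugged in.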
    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatD* __restrict__ p_d0_grid,
                               FloatD* __restrict__ p_d1_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const D0ReduceOperation& d0_reduce_op,
                               const D1ReduceOperation& d1_reduce_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
                               const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
        auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
            BlockSize, AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set, Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder,
            FloatAB, FloatAB, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder, Sequence<1, 0, 2>, ABlockTransferSrcVectorDim, 2,
            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, 1, 1,
            AThreadTransferSrcResetCoordinateAfterRun, true, NumGemmKPrefetchStage>(
            a_grid_desc_ak0_m_ak1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,
            a_block_desc_ak0_m_ak1,
            make_multi_index(0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
            BlockSize, BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set, Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,
            FloatAB, FloatAB, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder, Sequence<1, 0, 2>, BBlockTransferSrcVectorDim, 2,
            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, 1, 1,
            BThreadTransferSrcResetCoordinateAfterRun, true, NumGemmKPrefetchStage>(
            b_grid_desc_bk0_n_bk1,
            make_multi_index(0, n_block_data_idx_on_grid, 0),
            b_element_op,
            b_block_desc_bk0_n_bk1,
            make_multi_index(0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});
        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        // sanity check
        constexpr index_t KPack = math::max(
            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize, FloatAB, FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1),
            MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v1<
            remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
            remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
            remove_cvref_t<decltype(a_blockwise_copy)>,
            remove_cvref_t<decltype(a_grid_buf)>,
            remove_cvref_t<decltype(a_block_buf)>,
            remove_cvref_t<decltype(a_block_slice_copy_step)>,
            remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
            remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
            remove_cvref_t<decltype(b_blockwise_copy)>,
            remove_cvref_t<decltype(b_grid_buf)>,
            remove_cvref_t<decltype(b_block_buf)>,
            remove_cvref_t<decltype(b_block_slice_copy_step)>,
            remove_cvref_t<decltype(blockwise_gemm)>,
            remove_cvref_t<decltype(c_thread_buf)>,
            NumGemmKPrefetchStage,
            HasMainK0BlockLoop>{};

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
                                   a_blockwise_copy, a_grid_buf, a_block_buf,
                                   a_block_slice_copy_step,
                                   b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
                                   b_blockwise_copy, b_grid_buf, b_block_buf,
                                   b_block_slice_copy_step,
                                   blockwise_gemm, c_thread_buf, num_k_block_main_loop);
        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                FloatGemmAcc, FloatCShuffle,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         I1, I1, M2, I1, M4, I1>,
                Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                7, 1, InMemoryDataOperationEnum::Set, 1, true>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                ck::tensor_operation::element_wise::PassThrough{}};
            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1<
                BlockSize,                  // index_t BlockSize,
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                FloatC,               // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 c_element_op};
            // LDS c_reduce_block_desc_mperblock_nperblock
            constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_pass_through_transform(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
                    make_freeze_transform(I0),
                    make_pass_through_transform(
                        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));

            static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
                                  CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
                              BlockSize,
                          "wrong!");

            static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
                                      CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
                                  0 &&
                              (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
                                      CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
                                  0,
                          "wrong!");

            constexpr index_t mreduce_per_thread =
                (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
                CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);

            constexpr index_t nreduce_per_thread =
                (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
                CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);

            constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
                Sequence<mreduce_per_thread, nreduce_per_thread>{};

            // VGPR c_reduce_thread_desc_mperblock_nperblock
            constexpr auto c_reduce_thread_desc_mperblock_nperblock =
                make_naive_tensor_descriptor_packed(
                    make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

            // VGPR d_reduce_thread_desc_mperblock
            constexpr auto d_reduce_thread_desc_mperblock =
                make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

            // VGPR d_reduce_thread_desc_mblock_mperblock
            constexpr auto d_reduce_thread_desc_mblock_mperblock =
                make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

            // TODO: this should be implemented as a blockwise reduction
            auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
                c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

            auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
                d_reduce_thread_desc_mperblock.GetElementSpaceSize());

            auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
                d_reduce_thread_desc_mperblock.GetElementSpaceSize());

            // reduce: threadwise copy from LDS to VGPR
            constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
                CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});

            const auto c_reduce_thread_cluster_idx =
                c_reduce_thread_cluster_desc.CalculateBottomIndex(
                    make_multi_index(get_thread_local_1d_id()));

            const auto c_reduce_thread_data_idx_begin =
                c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;

            auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
                FloatCShuffle, FloatCShuffle,
                decltype(c_reduce_block_desc_mperblock_nperblock),
                decltype(c_reduce_thread_desc_mperblock_nperblock),
                decltype(c_reduce_thread_lengths_mperblock_nperblock),
                Sequence<0, 1>,
                1,
                CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
                1,
                true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
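
            // Worked example (hypothetical numbers, not values from this file): with
            // BlockSize = 256, CReduceThreadClusterLengths_MPerBlock_NPerBlock = Sequence<64, 4>
            // and a per-shuffle C tile of 128 x 64, the asserts above hold (64 * 4 == 256) and
            // each thread reduces a 2 x 16 patch (mreduce_per_thread = 128 / 64,
            // nreduce_per_thread = 64 / 4), i.e. 16 C values are folded into each of its two
            // partial D outputs.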
            // reduce: copy from VGPR to global
            auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
                FloatCShuffle, FloatD,
                decltype(d_reduce_thread_desc_mblock_mperblock),
                decltype(d_grid_desc_mblock_mperblock),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1, mreduce_per_thread>,
                Sequence<0, 1>,
                1,
                CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
                DGlobalMemoryDataOperation,
                1,
                false>{d_grid_desc_mblock_mperblock,
                       make_multi_index(block_work_idx[I0],                  // mblock
                                        c_reduce_thread_data_idx_begin[I0]), // mperblock
                       ck::tensor_operation::element_wise::PassThrough{}};

            auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;

            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1, 1, M2, 1, M4, 1>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
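
            // Illustration (hypothetical tuning values): with MXdlPerWave = 4, NXdlPerWave = 2,
            // CShuffleMXdlPerWavePerShuffle = 2 and CShuffleNXdlPerWavePerShuffle = 1, each
            // thread's accumulator is emitted in (4/2) * (2/1) = 4 slices, so num_access == 4
            // and the loop below performs four LDS-staged store (and reduce) passes over the
            // C tile.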
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copies its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                // reduce
                {
                    // copy from LDS to VGPR
                    c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
                                                         c_shuffle_block_buf,
                                                         c_reduce_thread_desc_mperblock_nperblock,
                                                         make_tuple(I0, I0),
                                                         c_reduce_thread_buf);

                    // reduce in VGPR
                    static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
                        FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
                        FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();

                        static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
                            constexpr auto offset =
                                Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                                    make_tuple(im, in))>{};

                            d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
                            d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
                        });

                        constexpr index_t out_offset =
                            d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));

                        d0_thread_buf(Number<out_offset>{}) = d0_acc;
                        d1_thread_buf(Number<out_offset>{}) = d1_acc;
                    });

                    // copy from VGPR to Global
                    d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
                                                             make_tuple(I0, I0),
                                                             d0_thread_buf,
                                                             d_grid_desc_mblock_mperblock,
                                                             d0_grid_buf);

                    d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
                                                             make_tuple(I0, I0),
                                                             d1_thread_buf,
                                                             d_grid_desc_mblock_mperblock,
                                                             d1_grid_buf);
                }

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);

                    // move on D0
                    d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                        d_grid_desc_mblock_mperblock,
                        make_tuple(c_global_step[I0], c_global_step[I1]));

                    // move on D1
                    d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                        d_grid_desc_mblock_mperblock,
                        make_tuple(c_global_step[I0], c_global_step[I1]));
                }
            });
        }
    }
};

} // namespace ck
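
To make the fused epilogue above easier to follow, here is a host-side reference sketch of what one M-row of the output sees. The reduction operators are template parameters of the gridwise struct; a sum / sum-of-squares pairing (the kind a mean/variance-style post-op would use) is assumed here purely for illustration, and the atomic combination of per-block partial results that DGlobalMemoryDataOperation would normally perform is left out.

#include <cstddef>
#include <vector>

// Reference for one row of C: d0 accumulates the row sum, d1 the row sum of squares.
// This mirrors d0_reduce_op / d1_reduce_op with hypothetical "Add" and "SquaredAdd" ops.
void reference_row_reduce(const std::vector<float>& c_row, float& d0, float& d1)
{
    d0 = 0.0f; // GetReduceZeroValue() analogue for an additive reduction
    d1 = 0.0f;
    for(std::size_t n = 0; n < c_row.size(); ++n)
    {
        d0 += c_row[n];            // d0_reduce_op.Reduce(d0_acc, c)
        d1 += c_row[n] * c_row[n]; // d1_reduce_op.Reduce(d1_acc, c), squaring inside the op
    }
}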
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
0 → 100644
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace ck {

template <typename GridwiseGemm, typename FloatAB, typename FloatC,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap, bool HasMainK0BlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid, p_b_grid, p_c_grid, p_shared,
        a_element_op, b_element_op, c_element_op,
        a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1,
        c_grid_desc_mblock_mperblock_nblock_nperblock,
        block_2_ctile_map);
}
template <typename FloatAB, typename FloatGemmAcc, typename FloatCShuffle, typename FloatC,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename CGridDesc_M_N,
          index_t NumGemmKPrefetchStage, index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t AK1Value, index_t BK1Value,
          index_t MPerXdl, index_t NPerXdl, index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatCShuffle
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
// "wrong! K1 need to be known at compile-time");
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
return
false
;
// check NumGemmKPrefetchStage
if
constexpr
(
NumGemmKPrefetchStage
==
1
)
{
// 1-stage prefetch always supported
}
else
if
constexpr
(
NumGemmKPrefetchStage
==
2
)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if
(
!
((
K
/
KPerBlock
)
%
2
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
// TODO move this function into GEMM-pipeline class
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
((
K0
*
AK1
)
/
(
NumGemmKPrefetchStage
*
KPerBlock
))
>
1
;
return
has_main_k0_block_loop
;
}
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
// FIXME: remove
constexpr
auto
M01
=
I1
;
constexpr
auto
N01
=
I1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
bool
HasMainK0BlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4r1
<
BlockSize
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
// gridwise GEMM pipeline
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v1
<
remove_cvref_t
<
decltype
(
a_grid_desc_ak0_m_ak1
)
>
,
remove_cvref_t
<
decltype
(
a_block_desc_ak0_m_ak1
)
>
,
remove_cvref_t
<
decltype
(
a_blockwise_copy
)
>
,
remove_cvref_t
<
decltype
(
a_grid_buf
)
>
,
remove_cvref_t
<
decltype
(
a_block_buf
)
>
,
remove_cvref_t
<
decltype
(
a_block_slice_copy_step
)
>
,
remove_cvref_t
<
decltype
(
b_grid_desc_bk0_n_bk1
)
>
,
remove_cvref_t
<
decltype
(
b_block_desc_bk0_n_bk1
)
>
,
remove_cvref_t
<
decltype
(
b_blockwise_copy
)
>
,
remove_cvref_t
<
decltype
(
b_grid_buf
)
>
,
remove_cvref_t
<
decltype
(
b_block_buf
)
>
,
remove_cvref_t
<
decltype
(
b_block_slice_copy_step
)
>
,
remove_cvref_t
<
decltype
(
blockwise_gemm
)
>
,
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
,
NumGemmKPrefetchStage
,
HasMainK0BlockLoop
>
{};
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
gridwise_gemm_pipeline
.
Run
(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C and write out
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatCShuffle
*>
(
p_shared
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatCShuffle
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
BlockwiseTensorSliceTransfer_v6r1
<
BlockSize
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatC
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_c_global
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_c_global
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
});
}
}
};
}
// namespace ck
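
Both gridwise structs expose the same host-callable helpers, so a device-op wrapper can gate its launch in the same way for the plain and the reducing variant. The sketch below is illustrative only; it assumes the ck headers are available and that the caller passes the AK0 length of the A descriptor as K0, and it is not code from this commit.

// Minimal host-side gating sketch for either gridwise GEMM above.
template <typename GridwiseGemm, typename ADesc, typename BDesc, typename CDesc>
bool prepare_launch(const ADesc& a_desc,
                    const BDesc& b_desc,
                    const CDesc& c_desc,
                    ck::index_t& grid_size,
                    bool& has_main_k0_loop)
{
    if(!GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc))
        return false; // tiles do not divide M/N/K, or the prefetch depth is unsupported

    grid_size        = GridwiseGemm::CalculateGridSize(c_desc);
    has_main_k0_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(a_desc.GetLength(ck::Number<0>{}));

    // The real device op then instantiates kernel_gemm_xdl_cshuffle_v1 (or the _reduce_
    // variant) with HasMainK0BlockLoop baked into the template arguments and launches it
    // with grid_size thread blocks of BlockSize threads.
    return true;
}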