gaoqiong / composable_kernel · Commits · 38a90b6e

Unverified commit 38a90b6e, authored Oct 20, 2021 by Chao Liu, committed by GitHub on Oct 20, 2021
Merge pull request #43 from ROCmSoftwarePlatform/develop

Merge develop into master

Parents: 88833bd9, c3018794
Changes: 71
Showing 20 changed files with 4485 additions and 580 deletions (+4485 -580)
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp  +90 -13
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp  +147 -0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp  +147 -0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp  +132 -0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp  +144 -0
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp  +1 -1
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  +41 -41
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp  +182 -260
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp  +666 -0
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp  +625 -0
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp  +503 -0
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp  +544 -0
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp  +376 -0
composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp  +271 -0
composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp  +141 -0
composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp  +371 -0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp  +3 -5
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +25 -75
composable_kernel/include/utility/amd_xdlops.hpp  +74 -183
composable_kernel/include/utility/config.hpp  +2 -2
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
View file @ 38a90b6e

@@ -21,8 +21,8 @@ template <typename... Wei,
           typename ConvDilations,
           typename InLeftPads,
           typename InRightPads,
-          index_t IYTildaValue,
-          index_t IXTildaValue,
+          typename IYTilda,
+          typename IXTilda,
           index_t GemmK1Value>
 __host__ __device__ constexpr auto
 transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     const ConvDilations& conv_dilations,
     const InLeftPads& in_left_pads,
     const InRightPads& in_right_pads,
-    Number<IYTildaValue>,
-    Number<IXTildaValue>,
+    IYTilda i_ytilda,
+    IXTilda i_xtilda,
     Number<GemmK1Value>)
 {
     constexpr auto I0 = Number<0>{};
@@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};
 
-    constexpr auto GemmK1  = Number<GemmK1Value>{};
-    constexpr auto IYTilda = Number<IYTildaValue>{};
-    constexpr auto IXTilda = Number<IXTildaValue>{};
+    constexpr auto GemmK1 = Number<GemmK1Value>{};
 
     const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
     const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
 
     // GemmK is different for each GEMM
-    const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
-    const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
+    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
+    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);
 
     const auto K1 = GemmK1;
     const auto K0 = K / K1;
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
         make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                    make_slice_transform(YDot, I0, YDotSlice),
                    make_slice_transform(XDot, I0, XDotSlice),
-                   make_freeze_transform(IYTilda),
-                   make_freeze_transform(IXTilda),
+                   make_freeze_transform(i_ytilda),
+                   make_freeze_transform(i_xtilda),
                    make_pass_through_transform(C)),
         make_tuple(Sequence<0>{},
                    Sequence<1>{},
@@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
     const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
         in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
         make_tuple(make_pass_through_transform(N),
-                   make_freeze_transform(IYTilda),
+                   make_freeze_transform(i_ytilda),
                    make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
-                   make_freeze_transform(IXTilda),
+                   make_freeze_transform(i_xtilda),
                    make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
                    make_pass_through_transform(C)),
         make_tuple(Sequence<0>{},
@@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
         in_gemmm_gemmn_grid_desc);
 }
 
+// A: out
+// B: wei
+// C: in
+// Number of GEMMs = 1
+// GemmM = N * Ho * Wo
+// GemmN = C
+// GemmK = K
+template <typename... Wei,
+          typename... In,
+          typename... Out,
+          typename ConvStrides,
+          index_t GemmK1Value>
+__host__ __device__ constexpr auto
+transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
+    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
+    const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
+    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
+    const ConvStrides& conv_strides,
+    Number<GemmK1Value>)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto GemmK1 = Number<GemmK1Value>{};
+
+    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
+    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
+    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
+
+    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
+    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
+
+    const auto ConvStrideH = conv_strides[I0];
+    const auto ConvStrideW = conv_strides[I1];
+
+    const auto K1 = GemmK1;
+    const auto K0 = K / K1;
+
+    // A: output tensor
+    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
+        make_tuple(make_pass_through_transform(N * Ho * Wo),
+                   make_unmerge_transform(make_tuple(K0, K1))),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+
+    // B: weight tensor
+    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
+        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
+                   make_pass_through_transform(C)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+    // C: input tensor
+    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+        in_n_hi_wi_c_grid_desc,
+        make_tuple(make_pass_through_transform(N),
+                   make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
+                   make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
+                   make_pass_through_transform(C)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
+        in_n_y_ho_x_wo_c_grid_desc,
+        make_tuple(make_freeze_transform(I0),
+                   make_freeze_transform(I0),
+                   make_merge_transform(make_tuple(N, Ho, Wo)),
+                   make_pass_through_transform(C)),
+        make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
+        make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
+
+    return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
+                      wei_gemmk0_gemmn_gemmk1_grid_desc,
+                      in_gemmm_gemmn_grid_desc);
+}
+
 } // namespace ck
 #endif
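Note (not part of the commit): for a 1x1 filter the new _1x1 overload above reduces backward data convolution to a single GEMM with GemmM = N * Ho * Wo, GemmN = C and GemmK = K, where K is further split into K0 * K1. The standalone sketch below, with purely hypothetical sizes, only illustrates that dimension bookkeeping; it does not use any composable_kernel headers.

    // Standalone illustration of the 1x1 backward-data GEMM dimensions (hypothetical sizes).
    #include <cassert>
    #include <cstdio>

    int main()
    {
        const int N = 128, Ho = 28, Wo = 28, C = 256, K = 1024; // example conv sizes
        const int GemmK1 = 8;                                   // plays the role of Number<GemmK1Value>

        const int GemmM = N * Ho * Wo; // rows of the implicit GEMM
        const int GemmN = C;           // columns
        const int GemmK = K;           // reduction length

        assert(GemmK % GemmK1 == 0);   // the descriptor above assumes an exact K0 * K1 split
        const int GemmK0 = GemmK / GemmK1;

        std::printf("GemmM=%d GemmN=%d GemmK=%d (K0=%d x K1=%d)\n",
                    GemmM, GemmN, GemmK, GemmK0, GemmK1);
        return 0;
    }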
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
new file (0 → 100644)
View file @ 38a90b6e

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType,
          typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch,
    GemmKPadType GemmKPad)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = K;
    const auto GemmN      = C * Y * X;
    const auto GemmKTotal = N * Ho * Wo;

    const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);

    // A: output tensor
    const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmm_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmkpad_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_c_y_ho_x_wo_grid_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmkpad_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
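Note (not part of the commit): the atomic backward-weight transforms in this merge split the reduction dimension GemmKTotal = N * Ho * Wo across GemmKBatch work-groups and right-pad it to GemmKPad so that the unmerge into (GemmKBatch, GemmK0, GemmK1) is exact, with GemmK0 = GemmKPad / (GemmKBatch * GemmK1). The header only consumes GemmKPad; how the host chooses it is not shown here. The sketch below, with hypothetical sizes, shows one minimal way to pick a valid GemmKPad (real host code may additionally round GemmK0 up to a multiple of the block tile).

    // Standalone illustration (hypothetical sizes): choosing a GemmKPad that makes the
    // (GemmKBatch, GemmK0, GemmK1) unmerge in the transform above exact.
    #include <cstdio>

    int main()
    {
        const int N = 64, Ho = 14, Wo = 14;   // example output sizes
        const int GemmKBatch = 4, GemmK1 = 8; // example split parameters

        const int GemmKTotal = N * Ho * Wo;
        const int step       = GemmKBatch * GemmK1;                    // pad granularity
        const int GemmKPad   = ((GemmKTotal + step - 1) / step) * step; // round up
        const int GemmK0     = GemmKPad / (GemmKBatch * GemmK1);

        std::printf("GemmKTotal=%d padded to GemmKPad=%d -> %d x %d x %d\n",
                    GemmKTotal, GemmKPad, GemmKBatch, GemmK0, GemmK1);
        return 0;
    }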
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
new file (0 → 100644)
View file @ 38a90b6e

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType,
          typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch,
    GemmKPadType GemmKPad)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;

    const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);

    // A: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmm_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmkpad_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: output tensor
    const auto out_gemmktotal_gemmn_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));

    const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmn_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmkpad_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
new file (0 → 100644)
View file @ 38a90b6e

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM = Y * X * C;
    const auto GemmN = K;
    const auto GemmK = N * Ho * Wo;

    const auto GemmK0 = GemmK / GemmK1;

    // A: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmk_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // B: output tensor
    const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmk_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
                      out_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp
new file (0 → 100644)
View file @ 38a90b6e

#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// A: out
// B: in
// C: wei
// GemmM = K
// GemmN = Y * X * C
// GemmKTotal = N * Ho * Wo
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType,
          typename GemmKPadType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch,
    GemmKPadType GemmKPad)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = K;
    const auto GemmN      = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;

    const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1);

    // A: output tensor
    const auto out_gemmktotal_gemmm_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));

    const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmm_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmkpad_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_y_ho_x_wo_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmkpad_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));

    return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
View file @ 38a90b6e

@@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform(
     return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
 }
 
-template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
+template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
 __host__ __device__ constexpr auto make_right_pad_transform(const LowLength& low_length,
                                                             const RightPadLength& right_pad,
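Note (not part of the commit): this one-line change defaults SkipIsValidCheck to false on make_right_pad_transform, so the two-argument call sites added elsewhere in this merge, such as make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), compile without naming the flag. The generic sketch below illustrates the defaulted non-type template parameter pattern only; RightPadStub and make_right_pad are hypothetical names, not ck APIs.

    // Generic illustration of a defaulted non-type template parameter (not ck code).
    #include <cstdio>

    struct RightPadStub { int low_length; int right_pad; bool skip_check; }; // hypothetical stand-in

    template <bool SkipIsValidCheck = false>
    RightPadStub make_right_pad(int length, int pad)
    {
        return RightPadStub{length, pad, SkipIsValidCheck};
    }

    int main()
    {
        auto p1 = make_right_pad(100, 28);       // default: validity check stays enabled
        auto p2 = make_right_pad<true>(100, 28); // explicit opt-out
        std::printf("%d %d %d %d\n", p1.low_length, p1.right_pad, int(p1.skip_check), int(p2.skip_check));
        return 0;
    }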
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @ 38a90b6e

@@ -10,6 +10,7 @@ namespace ck {
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename AK0MK1BlockDesc,
           typename BK0NK1BlockDesc,
           index_t MPerXDL,
@@ -29,14 +30,18 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
     static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
 
-    static constexpr index_t K0        = BK0NK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t KPerBlock = K0;
+    static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
 
     static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
 
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
 
+    StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
+        c_thread_buf_;
+
+    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
+
     __device__ static auto GetWaveIdx()
     {
         const index_t thread_id = get_thread_local_1d_id();
@@ -162,7 +167,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     {
         return transform_tensor_descriptor(
             AK0MK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+            make_tuple(make_pass_through_transform(Number<K0>{}),
                        make_unmerge_transform(
                            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
                        make_pass_through_transform(Number<K1>{})),
@@ -174,7 +179,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     {
         return transform_tensor_descriptor(
             BK0NK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+            make_tuple(make_pass_through_transform(Number<K0>{}),
                        make_unmerge_transform(
                            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
                        make_pass_through_transform(Number<K1>{})),
@@ -195,48 +200,43 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
             b_thread_desc_.GetElementSpaceSize());
 
-        vector_type<FloatAB, K1> a_thread_vec;
-        vector_type<FloatAB, K1> b_thread_vec;
-
-        static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
-            // read A
-            a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
-                               make_tuple(k0, I0, I0, I0, I0),
-                               a_block_buf,
-                               a_thread_desc_,
-                               make_tuple(I0, I0, I0, I0, I0),
-                               a_thread_buf);
-
-            // read B
-            b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
-                               make_tuple(k0, I0, I0, I0, I0),
-                               b_block_buf,
-                               b_thread_desc_,
-                               make_tuple(I0, I0, I0, I0, I0),
-                               b_thread_buf);
-
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    static_for<0, K1, 1>{}([&](auto i) {
-                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
-                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
-                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
-
-                    constexpr index_t c_offset =
-                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-
-                    xdlops_gemm.template Run<c_offset>(
-                        a_thread_vec.template AsType<mfma_input_type>(),
-                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf);
-                });
-            });
-        });
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
+            // read A
+            a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
+                               make_tuple(I0, m0, I0, I0, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, I0, I0, I0),
+                               a_thread_buf);
+
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
+                                   make_tuple(I0, n0, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, I0, I0, I0, I0),
+                                   b_thread_buf);
+
+                static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
+                    vector_type<FloatAB, K1> a_thread_vec;
+                    vector_type<FloatAB, K1> b_thread_vec;
+
+                    static_for<0, K1, 1>{}([&](auto i) {
+                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
+                            [Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
+                            [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
+                    });
+
+                    using mfma_input_type =
+                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0));
+
+                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                             b_thread_vec.template AsType<mfma_input_type>(),
+                                             c_thread_buf.GetVector(Number<c_offset>{}));
+                });
+            });
+        });
@@ -244,35 +244,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     private:
     // A[K, M]
-    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
 
     // B[K, N]
-    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
 
-    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
+    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
 
     using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          decltype(a_k0_m0_m1_m2_k1_block_desc),
                                                          decltype(a_thread_desc_),
-                                                         Sequence<1, MRepeat, 1, 1, K1>,
+                                                         Sequence<K0, 1, 1, 1, K1>,
                                                          Sequence<0, 1, 2, 3, 4>,
                                                          4,
                                                          K1,
                                                          1>;
 
     using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
                                                          decltype(b_k0_n0_n1_n2_k1_block_desc),
                                                          decltype(b_thread_desc_),
-                                                         Sequence<1, NRepeat, 1, 1, K1>,
+                                                         Sequence<K0, 1, 1, 1, K1>,
                                                          Sequence<0, 1, 2, 3, 4>,
                                                          4,
                                                          K1,
                                                          1>;
 
     AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
     BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
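Note (not part of the commit): after this change the blockwise GEMM object owns its accumulator, one StaticBufferV2 entry per (MRepeat, NRepeat) tile, and the gridwise kernel obtains it via GetCThreadBuffer() instead of sizing its own StaticBuffer. The simplified sketch below mirrors only that ownership pattern in plain C++; the types and the dummy Run body are stand-ins, not ck code.

    // Simplified illustration of the accumulator-ownership pattern (not ck code).
    #include <array>
    #include <cstdio>

    template <int MRepeat, int NRepeat, int AccPerXdlops>
    struct BlockwiseGemmSketch
    {
        using AccVec = std::array<float, AccPerXdlops>;        // stand-in for vector_type<FloatAcc, 16>
        std::array<AccVec, MRepeat * NRepeat> c_thread_buf_{}; // stand-in for StaticBufferV2

        auto& GetCThreadBuffer() { return c_thread_buf_; }

        void Run() // accumulate a dummy value per tile, in place of the MFMA calls
        {
            for(int m0 = 0; m0 < MRepeat; ++m0)
                for(int n0 = 0; n0 < NRepeat; ++n0)
                    c_thread_buf_[m0 * NRepeat + n0][0] += 1.0f;
        }
    };

    int main()
    {
        BlockwiseGemmSketch<2, 2, 16> blockwise_gemm;
        auto& c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // mirrors GetCThreadBuffer() usage
        blockwise_gemm.Run();
        std::printf("c[0][0] = %f\n", c_thread_buf[0][0]);
        return 0;
    }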
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
38a90b6e
...
...
@@ -29,7 +29,7 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
const
AK0MK1GridDesc
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
b_k0_n_k1_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_
m1_m2_n
_grid_desc
,
const
CM0N0M1N1M2M3M4N2GridDesc
c_m0_
n0_m1_n1_m2_m3_m4_n2
_grid_desc
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
constexpr
index_t
shared_block_size
=
...
...
@@ -132,7 +132,9 @@ template <index_t BlockSize,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
>
bool
CAccessOrderMRepeatNRepeat
,
bool
ABlockLdsExtraM
,
bool
BBlockLdsExtraN
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -142,6 +144,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
...
...
@@ -151,14 +154,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_k0_m_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_k0_n_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
...
...
@@ -170,29 +193,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
)
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
)
{
// TODO: turn on this
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
)
&&
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k0_n_k1_grid_desc
.
GetLength
(
I1
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
static_assert
((
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
)
&&
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
if
(
!
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
K0
==
b_k0_n_k1_grid_desc
.
GetLength
(
I0
)
&&
K1
==
a_k0_m_k1_grid_desc
.
GetLength
(
I2
)
&&
K1
==
b_k0_n_k1_grid_desc
.
GetLength
(
I2
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
KPerBlock
==
0
))
return
false
;
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
K0
==
b_k0_n_k1_grid_desc
.
GetLength
(
I0
)
&&
K1
==
a_k0_m_k1_grid_desc
.
GetLength
(
I2
)
&&
K1
==
b_k0_n_k1_grid_desc
.
GetLength
(
I2
))
&&
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
KPerBlock
==
0
);
return
true
;
}
__host__
__device__
static
constexpr
index_t
...
...
@@ -211,15 +250,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_k0_m_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
b_k0_n_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
MPerXDL
,
...
...
@@ -231,8 +295,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
BlockwiseGemm
::
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
c_m_n_grid_desc
);
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
...
...
@@ -243,23 +308,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
#if 1
const
auto
c_blockid_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
N0
))),
make_tuple
(
Sequence
<
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
#elif 1
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
c_blockid_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
M0
))),
make_tuple
(
Sequence
<
1
,
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
#endif
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}
,
1
,
1
));
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
@@ -294,14 +367,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
a_k0_m_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
b_k0_n_k1_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// A matrix blockwise copy
auto
a_blockwise_copy
=
...
...
@@ -363,9 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// register
// sanity check
const
auto
blockwise_gemm
=
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
MPerXDL
,
...
...
@@ -374,18 +468,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
NRepeat
,
K1
>
{};
constexpr
auto
c_mr_nr_blk_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{}));
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
blockwise_gemm
.
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
();
constexpr
auto
CBlkSize
=
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
.
GetElementSpaceSize
();
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
vector_type
<
FloatAcc
,
CBlkSize
>
,
c_mr_nr_blk_desc
.
GetElementSpaceSize
(),
true
>
c_thread_buf
;
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
...
...
@@ -460,9 +543,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
=
blockwise_gemm
.
GetCM0N0M1N1M2M3M4N2BlockDescriptor
();
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I7
);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
M0
>
{},
Number
<
N0
>
{},
I1
,
I1
,
Number
<
M2
>
{},
I1
,
Number
<
M4
>
{},
I1
));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
...
...
@@ -477,224 +569,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
=
CGridStepHacks
{};
const
auto
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_grid_idx
=
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_grid
));
const
auto
n_thread_data_on_grid_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_grid_idx
=
n_thread_data_on_grid_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_grid
));
auto
c_thread_copy
=
ThreadwiseTensorSliceTransfer_v1r3
<
Float
C
,
ThreadwiseTensorSliceTransfer_v1r3
<
Float
Acc
,
FloatC
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
),
Sequence
<
I1
,
I1
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
M0
,
N0
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
true
>
{
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
make_multi_index
(
0
,
0
,
0
,
0
,
m_thread_data_on_grid
/
(
M3
*
M4
),
m_thread_data_on_grid
%
(
M3
*
M4
)
/
M4
,
m_thread_data_on_grid
%
M4
,
n_thread_data_on_grid
)};
auto
init_copy
=
[
&
](
auto
c_thread_idx_
)
{
constexpr
auto
blk_off
=
c_mr_nr_blk_desc
.
CalculateOffset
(
c_thread_idx_
);
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
return
c_thread_idx_
;
};
auto
mrepeat_plus_copy
=
[
&
](
auto
c_thread_idx_
)
{
constexpr
auto
mrepeat_step_plus
=
make_multi_index
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
c_thread_copy
.
MoveDstSliceWindow
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
mrepeat_step_plus
);
constexpr
auto
blk_off
=
c_mr_nr_blk_desc
.
CalculateOffset
(
c_thread_idx_
);
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
};
auto
nrepeat_plus_copy
=
[
&
](
auto
c_thread_idx_
)
{
constexpr
auto
nrepeat_step_plus
=
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
);
c_thread_copy
.
MoveDstSliceWindow
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
nrepeat_step_plus
);
constexpr
auto
blk_off
=
c_mr_nr_blk_desc
.
CalculateOffset
(
c_thread_idx_
);
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
};
auto
mrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
constexpr
auto
mrepeat_step_plus
=
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
);
c_thread_copy
.
MoveDstSliceWindow
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
mrepeat_step_plus
);
constexpr
auto
blk_off
=
c_mr_nr_blk_desc
.
CalculateOffset
(
c_thread_idx_
);
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
};
auto
nrepeat_minus_copy
=
[
&
](
auto
c_thread_idx_
)
{
constexpr
auto
nrepeat_step_minus
=
make_multi_index
(
0
,
-
1
,
0
,
0
,
0
,
0
,
0
,
0
);
c_thread_copy
.
MoveDstSliceWindow
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
nrepeat_step_minus
);
constexpr
auto
blk_off
=
c_mr_nr_blk_desc
.
CalculateOffset
(
c_thread_idx_
);
c_thread_copy
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
[
Number
<
blk_off
>
{}].
template
AsType
<
FloatAcc
>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
);
};
        static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
                          (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
                          (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
                          (MRepeat == 1 && NRepeat == 1),
                      "wrong");

        if constexpr(MRepeat == 4 && NRepeat == 4)
        {
            init_copy(make_tuple(I0, I0));

            if constexpr(CAccessOrderMRepeatNRepeat)
            {
                nrepeat_plus_copy(make_tuple(I0, I1));
                nrepeat_plus_copy(make_tuple(I0, I2));
                nrepeat_plus_copy(make_tuple(I0, I3));
                mrepeat_plus_copy(make_tuple(I1, I3));
                nrepeat_minus_copy(make_tuple(I1, I2));
                nrepeat_minus_copy(make_tuple(I1, I1));
                nrepeat_minus_copy(make_tuple(I1, I0));
                mrepeat_plus_copy(make_tuple(I2, I0));
                nrepeat_plus_copy(make_tuple(I2, I1));
                nrepeat_plus_copy(make_tuple(I2, I2));
                nrepeat_plus_copy(make_tuple(I2, I3));
                mrepeat_plus_copy(make_tuple(I3, I3));
                nrepeat_minus_copy(make_tuple(I3, I2));
                nrepeat_minus_copy(make_tuple(I3, I1));
                nrepeat_minus_copy(make_tuple(I3, I0));
            }
            else
            {
                mrepeat_plus_copy(make_tuple(I1, I0));
                mrepeat_plus_copy(make_tuple(I2, I0));
                mrepeat_plus_copy(make_tuple(I3, I0));
                nrepeat_plus_copy(make_tuple(I3, I1));
                mrepeat_minus_copy(make_tuple(I2, I1));
                mrepeat_minus_copy(make_tuple(I1, I1));
                mrepeat_minus_copy(make_tuple(I0, I1));
                nrepeat_plus_copy(make_tuple(I0, I2));
                mrepeat_plus_copy(make_tuple(I1, I2));
                mrepeat_plus_copy(make_tuple(I2, I2));
                mrepeat_plus_copy(make_tuple(I3, I2));
                nrepeat_plus_copy(make_tuple(I3, I3));
                mrepeat_minus_copy(make_tuple(I2, I3));
                mrepeat_minus_copy(make_tuple(I1, I3));
                mrepeat_minus_copy(make_tuple(I0, I3));
            }
        }
        else if constexpr(MRepeat == 4 && NRepeat == 2)
        {
            init_copy(make_tuple(I0, I0));

            if constexpr(CAccessOrderMRepeatNRepeat)
            {
                nrepeat_plus_copy(make_tuple(I0, I1));
                mrepeat_plus_copy(make_tuple(I1, I1));
                nrepeat_minus_copy(make_tuple(I1, I0));
                mrepeat_plus_copy(make_tuple(I2, I0));
                nrepeat_plus_copy(make_tuple(I2, I1));
                mrepeat_plus_copy(make_tuple(I3, I1));
                nrepeat_minus_copy(make_tuple(I3, I0));
            }
            else
            {
                mrepeat_plus_copy(make_tuple(I1, I0));
                mrepeat_plus_copy(make_tuple(I2, I0));
                mrepeat_plus_copy(make_tuple(I3, I0));
                nrepeat_plus_copy(make_tuple(I3, I1));
                mrepeat_minus_copy(make_tuple(I2, I1));
                mrepeat_minus_copy(make_tuple(I1, I1));
                mrepeat_minus_copy(make_tuple(I0, I1));
            }
        }
        else if constexpr(MRepeat == 2 && NRepeat == 4)
        {
            init_copy(make_tuple(I0, I0));

            if constexpr(CAccessOrderMRepeatNRepeat)
            {
                nrepeat_plus_copy(make_tuple(I0, I1));
                nrepeat_plus_copy(make_tuple(I0, I2));
                nrepeat_plus_copy(make_tuple(I0, I3));
                mrepeat_plus_copy(make_tuple(I1, I3));
                nrepeat_minus_copy(make_tuple(I1, I2));
                nrepeat_minus_copy(make_tuple(I1, I1));
                nrepeat_minus_copy(make_tuple(I1, I0));
            }
            else
            {
                mrepeat_plus_copy(make_tuple(I1, I0));
                nrepeat_plus_copy(make_tuple(I1, I1));
                mrepeat_minus_copy(make_tuple(I0, I1));
                nrepeat_plus_copy(make_tuple(I0, I2));
                mrepeat_plus_copy(make_tuple(I1, I2));
                nrepeat_plus_copy(make_tuple(I1, I3));
                mrepeat_minus_copy(make_tuple(I0, I3));
            }
        }
        else if constexpr(MRepeat == 2 && NRepeat == 2)
        {
            init_copy(make_tuple(I0, I0));

            if constexpr(CAccessOrderMRepeatNRepeat)
            {
                nrepeat_plus_copy(make_tuple(I0, I1));
                mrepeat_plus_copy(make_tuple(I1, I1));
                nrepeat_minus_copy(make_tuple(I1, I0));
            }
            else
            {
                mrepeat_plus_copy(make_tuple(I1, I0));
                nrepeat_plus_copy(make_tuple(I1, I1));
                mrepeat_minus_copy(make_tuple(I0, I1));
            }
        }
        else if constexpr(MRepeat == 2 && NRepeat == 1)
        {
            init_copy(make_tuple(I0, I0));
            mrepeat_plus_copy(make_tuple(I1, I0));
        }
        else if constexpr(MRepeat == 1 && NRepeat == 2)
        {
            init_copy(make_tuple(I0, I0));
            nrepeat_plus_copy(make_tuple(I0, I1));
        }
        else if constexpr(MRepeat == 1 && NRepeat == 1)
        {
            init_copy(make_tuple(I0, I0));
        }

            make_multi_index(m_thread_data_on_grid_idx[I0],
                             n_thread_data_on_grid_idx[I0],
                             m_thread_data_on_grid_idx[I1],
                             n_thread_data_on_grid_idx[I1],
                             m_thread_data_on_grid_idx[I2],
                             m_thread_data_on_grid_idx[I3],
                             m_thread_data_on_grid_idx[I4],
                             n_thread_data_on_grid_idx[I2])};

        c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                          c_thread_buf,
                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                          c_grid_buf,
                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
        }
    }
};

} // namespace ck
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CM0N0M1N1M2M3M4N2GridDesc,
          typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
                                const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
                                const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                const CBlockClusterAdaptor c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_b_k0_m_k1_grid_desc,
                      b_b_k0_n_k1_grid_desc,
                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                      c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename ABK0MK1GridDesc,
          typename BBK0NK1GridDesc,
          typename CM0N0M1N1M2M3M4N2GridDesc,
          typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
                                const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
                                const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                const void CONSTANT* p_c_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast<const ABK0MK1GridDesc*>(
        cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc));
    const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast<const BBK0NK1GridDesc*>(
        cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc));
    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
        *reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
            cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
    const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
        cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_b_k0_m_k1_grid_desc,
                      b_b_k0_n_k1_grid_desc,
                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                      c_block_cluster_adaptor);
}
#endif
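// Added note: two ways of handing the tensor descriptors to the kernel are compiled above.
// CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE passes the descriptor objects as kernel
// arguments, while CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER passes opaque CONSTANT
// pointers that the kernel casts back to the descriptor types via
// cast_pointer_to_generic_address_space. In the second mode the host is expected to have placed
// the descriptor objects in device-visible memory beforehand (how that buffer is produced is
// outside this file).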
template <index_t BlockSize,
          typename FloatAB, typename FloatAcc, typename FloatC,
          InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename ABK0MK1GridDesc, typename BBK0NK1GridDesc, typename CMNGridDesc,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL, index_t K1Value,
          index_t MRepeat, index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks, typename BGridStepHacks, typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks, typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat,
          bool ABlockLdsExtraM, bool BBlockLdsExtraN>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
        }();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size =
            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
    }
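    // Added note with illustrative numbers (not taken from this file): with FloatAB = half,
    // KPerBlock = 4, MPerBlock = 256, NPerBlock = 128, K1 = 8 and both LDS padding flags off,
    // the A tile needs 4 * 256 * 8 = 8192 elements and the B tile 4 * 128 * 8 = 4096 elements,
    // i.e. (8192 + 4096) * sizeof(half) = 24 KiB of LDS. The "+ 1" padded descriptors above
    // enlarge the M/N stride by one K1-vector per row to reduce LDS bank conflicts.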
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                                                            const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                                            const CMNGridDesc& c_m_n_grid_desc,
                                                            index_t M01,
                                                            index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                          (NPerBlock % (NRepeat * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M      = a_b_k0_m_k1_grid_desc.GetLength(I2);
        const auto N      = b_b_k0_n_k1_grid_desc.GetLength(I2);
        const auto K0     = a_b_k0_m_k1_grid_desc.GetLength(I1);
        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

        if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
             K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
             K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
             K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
             KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
            return false;

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }
    __host__ __device__ static constexpr index_t CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc,
                                                                   index_t KBatch)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;

        return grid_size;
    }
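    // Added example (hypothetical sizes): for a 4096 x 4096 C matrix with MPerBlock = 256,
    // NPerBlock = 128 and KBatch = 4, CalculateGridSize returns
    // (4096 / 256) * (4096 / 128) * 4 = 16 * 32 * 4 = 2048 workgroups; each k-batch slice owns a
    // full M x N grid of tiles and reduces a disjoint part of the K dimension.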
    __host__ __device__ static constexpr auto
    MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
        }();

        using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize, FloatAB, FloatAcc,
            decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc),
            MPerXDL, NPerXDL, MRepeat, NRepeat, K1>;

        return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
        const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(KBatch),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor);

        return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor;
    }

    using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
    using CBlockClusterAdaptor      = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
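    // Added note: MakeCBlockClusterAdaptor composes two adaptors so that a flat block id decodes
    // to (kbatch, m0, n0): the merge adaptor splits block_id into (kbatch, m00, m01, n00, n01),
    // and the unmerge adaptor folds (m00, m01) / (n00, n01) back into tile coordinates, with
    // M01/N01 controlling how many consecutive block ids stay inside the same super-tile.
    // E.g. (hypothetical) with KBatch = 2, M0 = N0 = 4 and M01 = N01 = 2, block ids 0..15 all
    // land in k-batch 0 and ids 16..31 in k-batch 1.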
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                               const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                               const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
        }();
        constexpr auto a_b_k0_m_k1_block_desc = [&]() {
            if constexpr(ABlockLdsExtraM)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<KPerBlock>{} * Number<MPerBlock + 1>{} * K1,
                               Number<MPerBlock + 1>{} * K1,
                               K1,
                               I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
                    max_lds_align);
        }();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
        }();

        constexpr auto b_b_k0_n_k1_block_desc = [&]() {
            if constexpr(BBlockLdsExtraN)
                return make_naive_tensor_descriptor(
                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<KPerBlock>{} * Number<NPerBlock + 1>{} * K1,
                               Number<NPerBlock + 1>{} * K1,
                               K1,
                               I1));
            else
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
                    max_lds_align);
        }();
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<
            BlockSize, InMemoryDataOperationEnum_t::Set,
            Sequence<1, KPerBlock, MPerBlock, K1>,
            ABlockTransferThreadSliceLengths_K0_M_K1,
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB, FloatAB,
            decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc),
            ABlockTransferSrcAccessOrder, Sequence<0, 2, 1, 3>,
            ABlockTransferSrcVectorDim, 3,
            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1,
            1, 1, AThreadTransferSrcResetCoordinateAfterRun, true>(
            a_b_k0_m_k1_grid_desc,
            make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
            a_b_k0_m_k1_block_desc,
            make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4<
            BlockSize, InMemoryDataOperationEnum_t::Set,
            Sequence<1, KPerBlock, NPerBlock, K1>,
            BBlockTransferThreadSliceLengths_K0_N_K1,
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB, FloatAB,
            decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc),
            BBlockTransferSrcAccessOrder, Sequence<0, 2, 1, 3>,
            BBlockTransferSrcVectorDim, 3,
            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1,
            1, 1, BThreadTransferSrcResetCoordinateAfterRun, true>(
            b_b_k0_n_k1_grid_desc,
            make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
            b_b_k0_n_k1_block_desc,
            make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        // sanity check
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize, FloatAB, FloatAcc,
            decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc),
            MPerXDL, NPerXDL, MRepeat, NRepeat, K1>{};
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
        constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
        constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};

        // hack to control index calculation when move slice window for A and B matrix for
        // threadwise copy
        constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
        constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
        }
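        // Added note: the loop below overlaps global reads with MFMA work. Each iteration
        // advances the A/B source windows by KPerBlock, issues the next RunRead while the tiles
        // already in LDS are still being consumed, runs blockwise_gemm on those LDS tiles, and
        // only then syncs and overwrites LDS with the prefetched data. The extra
        // blockwise_gemm.Run in the tail consumes the last tile written inside the loop.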
        // main body
        index_t k_block_data_begin = 0;

        do
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc,
                                                a_block_slice_copy_step,
                                                a_k0_m_k1_grid_move_slice_window_step_hack);
            b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc,
                                                b_block_slice_copy_step,
                                                b_k0_n_k1_grid_move_slice_window_step_hack);

            a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

            block_sync_lds();

            b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);

            k_block_data_begin += KPerBlock;
        } while(k_block_data_begin < (K0 - KPerBlock));

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
        // output: register to global memory
        {
            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();

            constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
            constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
            constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
            constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
            constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                make_naive_tensor_descriptor_packed(make_tuple(
                    Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc, FloatC,
                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
                Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                      make_multi_index(m_thread_data_on_grid_idx[I0],
                                       n_thread_data_on_grid_idx[I0],
                                       m_thread_data_on_grid_idx[I1],
                                       n_thread_data_on_grid_idx[I1],
                                       m_thread_data_on_grid_idx[I2],
                                       m_thread_data_on_grid_idx[I3],
                                       m_thread_data_on_grid_idx[I4],
                                       n_thread_data_on_grid_idx[I2])};

            c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                              c_grid_buf,
                              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
        }
    }
};

} // namespace ck
#endif
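Reviewer note: the v2r4 variant differs from v2r3 mainly by the extra leading "B" (k-batch)
dimension in the A/B grid descriptors, which enables a split-K style GEMM: each k-batch id works
on its own slice of the reduction dimension and the partial C results are combined through
CGlobalMemoryDataOperation (presumably an atomic add when KBatch > 1, judging by the *_atomic_*
problem-transform headers added in this commit). A minimal sketch of one plausible way a caller
could partition K, with hypothetical sizes and variable names (not code from this repository):

    #include <cstdio>

    int main()
    {
        const int K = 4096, K1 = 8, KBatch = 4; // hypothetical problem/tuning sizes
        const int K0 = K / K1;                  // K is stored as K0 x K1
        const int K0PerBatch = K0 / KBatch;     // each k-batch reduces this many K0 slices
        for(int b = 0; b < KBatch; ++b)
            std::printf("kbatch %d reduces K0 range [%d, %d)\n",
                        b, b * K0PerBatch, (b + 1) * K0PerBatch);
        return 0;
    }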
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_blockwise.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
namespace ck {

template <index_t BlockSize,
          typename srcDataType,
          typename dstDataType,
          typename compType,
          typename src2dDescType,
          typename dst1dDescType,
          ReduceTensorOp_t op,
          NanPropagation_t nanPropaOpt,
          ReduceTensorIndices_t reduceIndicesOpt,
          bool isFirstCall,
          bool isLastCall,
          index_t GredAccessesPerThreadInBlock>
struct GridwiseReduction_xy_to_x_blockwise
{
    using opReduce = typename reduce_binary_operator<compType, op>::opType;
    using preUnaryOpType =
        typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
    using posUnaryOpType =
        typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;

    static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));

    using blockwise_reduce =
        BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;

    static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();

    static constexpr auto I0 = Number<0>{};

    template <int RunId>
    __device__ static void Run(const src2dDescType& src2dDesc,
                               const dst1dDescType& dst1dDesc,
                               int origReduceLen,
                               srcDataType alpha,
                               const srcDataType* const __restrict__ p_src_global,
                               dstDataType beta,
                               dstDataType* const __restrict__ p_dst_global,
                               const int* const __restrict__ ws_indices_global,
                               int* const __restrict__ indices_global);
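    // Added note on the RunId dispatch below:
    //   Run<1>: plain reduction of values, writing only the reduced values to p_dst_global;
    //   Run<2>: reduction that also tracks the winning index (e.g. for MIN/MAX/AMAX), writing
    //           both the values and indices_global;
    //   Run<3>: second pass over a workspace produced by an earlier kernel, consuming
    //           ws_values_global together with ws_indices_global.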
    template <>
    __device__ static void Run<1>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)ws_indices_global;
        (void)indices_global;

        // LDS
        __shared__ compType p_in_block_buffer[BlockBufferSize];

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());

        auto in_block_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;

        accuValue_buf(I0) = zeroVal;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);
        const posUnaryOpType posUnaryOp(divider);

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();

        constexpr auto in_block_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4<
            BlockSize, InMemoryDataOperationEnum_t::Set, Sequence<1, BlockBufferSize>,
            ThreadSliceLengths, ThreadClusterLengths, Sequence<0, 1>, srcDataType, compType,
            src2dDescType, decltype(in_block_desc), Sequence<0, 1>, Sequence<0, 1>,
            1, 1, 1, 1, 1, 1, false, true>(
            src2dDesc, make_multi_index(block_global_1d_id, 0), in_block_desc, make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            blockwise_src_load.RunRead(src2dDesc, src_global_buf);
            blockwise_src_load.RunWrite(in_block_desc, in_block_buf);

            __syncthreads();

            // do element-wise pre-reduction operation
            blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;

            blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));

            blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            if(!float_equal_one{}(alpha))
                accuValue_buf(I0) *= type_convert<compType>{}(alpha);

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

            dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                    dstDataType, dstDataType, dst1dDescType, decltype(ReducedDataDesc),
                    Sequence<1>, Sequence<0>, 0, 1, 1, false>(
                    dst1dDesc, make_multi_index(block_global_1d_id));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

                threadwise_dst_load.Run(
                    dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

                dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
            }

            auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<
                dstDataType, dstDataType, decltype(ReducedDataDesc), dst1dDescType,
                Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, false>(
                dst1dDesc, make_multi_index(block_global_1d_id));

            threadwise_dst_store.Run(
                ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
        }
    };
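    // Added note: the net effect of Run<1> for the output element d owned by this block is
    //   d = alpha * posUnaryOp(reduce(preUnaryOp(x_0), ..., preUnaryOp(x_{n-1}))) + beta * d_prior
    // where reduce is opReduce over the row assigned to this block, origReduceLen is the divider
    // handed to the unary ops (e.g. for averaging), and the beta term is skipped when beta == 0.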
    template <>
    __device__ static void Run<2>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)ws_indices_global;

        // LDS
        __shared__ compType p_in_block_buffer[BlockBufferSize];
        __shared__ int block_indices_buffer[BlockBufferSize];

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());
        auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            indices_global, dst1dDesc.GetElementSpaceSize());

        auto in_block_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
        auto in_block_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(block_indices_buffer, BlockBufferSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();

        constexpr auto in_block_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_load = BlockwiseTensorSliceTransfer_v4<
            BlockSize, InMemoryDataOperationEnum_t::Set, Sequence<1, BlockBufferSize>,
            ThreadSliceLengths, ThreadClusterLengths, Sequence<0, 1>, srcDataType, compType,
            src2dDescType, decltype(in_block_desc), Sequence<0, 1>, Sequence<0, 1>,
            1, 1, 1, 1, 1, 1, false, true>(
            src2dDesc, make_multi_index(block_global_1d_id, 0), in_block_desc, make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;

        int indexOffset = 0;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            // load block data from global to LDS, no use of double buffers (to be improved)
            blockwise_src_load.RunRead(src2dDesc, src_global_buf);
            blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);

            __syncthreads();

            // construct the indices for the current toReduce blocks
            blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);

            // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
            // done here
            blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;

            blockwise_reduce::Reduce2(in_block_val_buf,
                                      in_block_idx_buf,
                                      BlocksInOneOp,
                                      accuValue_buf(I0),
                                      accuIndex_buf(I0));

            indexOffset += BlockBufferSize;

            blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            if(!float_equal_one{}(alpha))
                accuValue_buf(I0) *= type_convert<compType>{}(alpha);

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

            dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                    dstDataType, dstDataType, dst1dDescType, decltype(ReducedDataDesc),
                    Sequence<1>, Sequence<0>, 0, 1, 1, false>(
                    dst1dDesc, make_multi_index(block_global_1d_id));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

                threadwise_dst_load.Run(
                    dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

                dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
            }

            auto threadwise_dst_val_store = ThreadwiseTensorSliceTransfer_v1r3<
                dstDataType, dstDataType, decltype(ReducedDataDesc), dst1dDescType,
                Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, false>(
                dst1dDesc, make_multi_index(block_global_1d_id));
            auto threadwise_dst_idx_store = ThreadwiseTensorSliceTransfer_v1r3<
                int, int, decltype(ReducedDataDesc), dst1dDescType,
                Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, false>(
                dst1dDesc, make_multi_index(block_global_1d_id));

            threadwise_dst_val_store.Run(
                ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
            threadwise_dst_idx_store.Run(
                ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
        }
    };
    template <>
    __device__ static void Run<3>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ ws_values_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)origReduceLen;

        // LDS
        __shared__ compType p_in_block_buffer[BlockBufferSize];
        __shared__ int block_indices_buffer[BlockBufferSize];

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_indices_global, src2dDesc.GetElementSpaceSize());
        auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());
        auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            indices_global, dst1dDesc.GetElementSpaceSize());

        auto in_block_val_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
        auto in_block_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(block_indices_buffer, BlockBufferSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();

        constexpr auto in_block_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<BlockBufferSize>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_val_load = BlockwiseTensorSliceTransfer_v4<
            BlockSize, InMemoryDataOperationEnum_t::Set, Sequence<1, BlockBufferSize>,
            ThreadSliceLengths, ThreadClusterLengths, Sequence<0, 1>, srcDataType, compType,
            src2dDescType, decltype(in_block_desc), Sequence<0, 1>, Sequence<0, 1>,
            1, 1, 1, 1, 1, 1, false, true>(
            src2dDesc, make_multi_index(block_global_1d_id, 0), in_block_desc, make_multi_index(0, 0));
        auto blockwise_src_idx_load = BlockwiseTensorSliceTransfer_v4<
            BlockSize, InMemoryDataOperationEnum_t::Set, Sequence<1, BlockBufferSize>,
            ThreadSliceLengths, ThreadClusterLengths, Sequence<0, 1>, int, int,
            src2dDescType, decltype(in_block_desc), Sequence<0, 1>, Sequence<0, 1>,
            1, 1, 1, 1, 1, 1, false, true>(
            src2dDesc, make_multi_index(block_global_1d_id, 0), in_block_desc, make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (toReduceLength + BlockSize - 1) / BlockSize;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            // load block data from global to LDS, no use of double buffers (to be improved)
            blockwise_src_val_load.RunRead(src2dDesc, src_global_val_buf);
            blockwise_src_idx_load.RunRead(src2dDesc, src_global_idx_buf);

            blockwise_src_val_load.RunWrite(in_block_desc, in_block_val_buf);
            blockwise_src_idx_load.RunWrite(in_block_desc, in_block_idx_buf);

            __syncthreads();

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;

            blockwise_reduce::Reduce2(in_block_val_buf,
                                      in_block_idx_buf,
                                      BlocksInOneOp,
                                      accuValue_buf(I0),
                                      accuIndex_buf(I0));

            blockwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
            blockwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            if(!float_equal_one{}(alpha))
                accuValue_buf(I0) *= type_convert<compType>{}(alpha);

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

            dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                    dstDataType, dstDataType, dst1dDescType, decltype(ReducedDataDesc),
                    Sequence<1>, Sequence<0>, 0, 1, 1, true>(
                    dst1dDesc, make_multi_index(block_global_1d_id));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

                threadwise_dst_load.Run(
                    dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

                dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
            }

            auto threadwise_dst_val_store = ThreadwiseTensorSliceTransfer_v1r3<
                dstDataType, dstDataType, decltype(ReducedDataDesc), dst1dDescType,
                Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, true>(
                dst1dDesc, make_multi_index(block_global_1d_id));
            auto threadwise_dst_idx_store = ThreadwiseTensorSliceTransfer_v1r3<
                int, int, decltype(ReducedDataDesc), dst1dDescType,
                Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, true>(
                dst1dDesc, make_multi_index(block_global_1d_id));

            threadwise_dst_val_store.Run(
                ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
            threadwise_dst_idx_store.Run(
                ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
        }
    };
};

} // namespace ck
#endif
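Reviewer note: both reduction kernels in this commit see the problem through a 2D view
[invariant length, to-reduce length] (src2dDesc) collapsing to a 1D output (dst1dDesc); this
blockwise variant assigns one workgroup per output element, while the direct_threadwise variant
below assigns one thread per output element. A minimal sketch of that flattening, using a
hypothetical 3D tensor reduced over its last two dimensions (illustrative only, not this
repository's host code):

    // Reduce a [N, H, W] tensor over (H, W): invariant length N, reduce length H * W.
    const int N = 64, H = 32, W = 32;      // hypothetical shape
    const int invariantLength = N;         // one output value per n
    const int toReduceLength  = H * W;     // elements each output reduces over
    // src2dDesc would then describe an invariantLength x toReduceLength view of the input,
    // and dst1dDesc an invariantLength-element output vector.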
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_THREADWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {

template <index_t BlockSize,
          typename srcDataType,
          typename dstDataType,
          typename compType,
          typename src2dDescType,
          typename dst1dDescType,
          ReduceTensorOp_t op,
          NanPropagation_t nanPropaOpt,
          ReduceTensorIndices_t reduceIndicesOpt,
          bool isFirstCall,
          bool isLastCall,
          index_t GredThreadBufferLength>
struct GridwiseReduction_xy_to_x_direct_threadwise
{
    using opReduce = typename reduce_binary_operator<compType, op>::opType;
    using preUnaryOpType =
        typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
    using posUnaryOpType =
        typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;

    static constexpr auto I0 = Number<0>{};

    template <int RunId>
    __device__ static void Run(const src2dDescType& src2dDesc,
                               const dst1dDescType& dst1dDesc,
                               int origReduceLen,
                               srcDataType alpha,
                               const srcDataType* const __restrict__ p_src_global,
                               dstDataType beta,
                               dstDataType* const __restrict__ p_dst_global,
                               const int* const __restrict__ ws_indices_global,
                               int* const __restrict__ indices_global);
    template <>
    __device__ static void Run<1>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)ws_indices_global;
        (void)indices_global;

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true> in_thread_buf;

        using threadwise_reduce = ThreadReduce<decltype(in_thread_buf), opReduce, nanPropaOpt>;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;

        accuValue_buf(I0) = zeroVal;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);
        const posUnaryOpType posUnaryOp(divider);

        using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
        constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            srcDataType, compType, src2dDescType, decltype(ThreadBufferDesc),
            ThreadBufferLengths, Sequence<0, 1>, 1, 1, 1, false>(
            src2dDesc, make_multi_index(thread_global_1d_id, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);

        for(index_t reducedLength = 0; reducedLength < toReduceLength;
            reducedLength += GredThreadBufferLength)
        {
            threadwise_src_load.Run(
                src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);

            // do element-wise pre-reduction operation
            threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);

            // do the reduction on the Thread Buffer
            threadwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0));

            threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
        }

        accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        if(!float_equal_one{}(alpha))
            accuValue_buf(I0) *= type_convert<compType>{}(alpha);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                dstDataType, dstDataType, dst1dDescType, decltype(ReducedDataDesc),
                Sequence<1>, Sequence<0>, 0, 1, 1, true>(
                dst1dDesc, make_multi_index(thread_global_1d_id));

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

            threadwise_dst_load.Run(
                dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
        }

        auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<
            dstDataType, dstDataType, decltype(ReducedDataDesc), dst1dDescType,
            Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, true>(
            dst1dDesc, make_multi_index(thread_global_1d_id));

        threadwise_dst_store.Run(
            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
    };
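    // Added note: the reduce dimension in Run<1> is consumed GredThreadBufferLength elements at a
    // time, so a row of length toReduceLength takes ceil(toReduceLength / GredThreadBufferLength)
    // iterations of the for loop; e.g. (hypothetical) toReduceLength = 1000 with
    // GredThreadBufferLength = 8 gives 125 iterations per thread.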
    template <>
    __device__ static void Run<2>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)ws_indices_global;

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());
        auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            indices_global, dst1dDesc.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredThreadBufferLength, true> in_thread_buf;

        using threadwise_reduce = ThreadReduce<decltype(in_thread_buf), opReduce, nanPropaOpt>;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        using ThreadBufferLengths = Sequence<1, GredThreadBufferLength>;
        constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<GredThreadBufferLength>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            srcDataType, compType, src2dDescType, decltype(ThreadBufferDesc),
            ThreadBufferLengths, Sequence<0, 1>, 1, 1, 1, false>(
            src2dDesc, make_multi_index(thread_global_1d_id, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, GredThreadBufferLength);

        index_t indexStart = 0;

        for(index_t reducedLength = 0; reducedLength < toReduceLength;
            reducedLength += GredThreadBufferLength)
        {
            threadwise_src_load.Run(
                src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);

            // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
            // done here
            threadwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);

            // do the reduction on the Thread Buffer
            threadwise_reduce::Reduce2(
                in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexStart);

            indexStart += GredThreadBufferLength;

            threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        if(!float_equal_one{}(alpha))
            accuValue_buf(I0) *= type_convert<compType>{}(alpha);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                dstDataType, dstDataType, dst1dDescType, decltype(ReducedDataDesc),
                Sequence<1>, Sequence<0>, 0, 1, 1, false>(
                dst1dDesc, make_multi_index(thread_global_1d_id));

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

            threadwise_dst_load.Run(
                dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
        }

        auto threadwise_dst_val_store = ThreadwiseTensorSliceTransfer_v1r3<
            dstDataType, dstDataType, decltype(ReducedDataDesc), dst1dDescType,
            Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, false>(
            dst1dDesc, make_multi_index(thread_global_1d_id));
        auto threadwise_dst_idx_store = ThreadwiseTensorSliceTransfer_v1r3<
            int, int, decltype(ReducedDataDesc), dst1dDescType,
            Sequence<1>, Sequence<0>, 0, 1, InMemoryDataOperationEnum_t::Set, 1, false>(
            dst1dDesc, make_multi_index(thread_global_1d_id));

        threadwise_dst_val_store.Run(
            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
        threadwise_dst_idx_store.Run(
            ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
    };
template
<
>
__device__
static
void
Run
<
3
>
(
const
src2dDescType
&
src2dDesc
,
const
dst1dDescType
&
dst1dDesc
,
int
origReduceLen
,
srcDataType
alpha
,
const
srcDataType
*
const
__restrict__
ws_values_global
,
dstDataType
beta
,
dstDataType
*
const
__restrict__
p_dst_global
,
const
int
*
const
__restrict__
ws_indices_global
,
int
*
const
__restrict__
indices_global
)
{
(
void
)
origReduceLen
;
const
auto
zeroVal
=
opReduce
::
GetReductionZeroVal
();
const
auto
src_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_values_global
,
src2dDesc
.
GetElementSpaceSize
(),
type_convert
<
srcDataType
>
{}(
zeroVal
));
const
auto
src_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
ws_indices_global
,
src2dDesc
.
GetElementSpaceSize
());
auto
dst_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_dst_global
,
dst1dDesc
.
GetElementSpaceSize
());
auto
dst_global_idx_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
indices_global
,
dst1dDesc
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
compType
,
GredThreadBufferLength
,
true
>
in_thread_val_buf
;
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
int
,
GredThreadBufferLength
,
true
>
in_thread_idx_buf
;
using
threadwise_reduce
=
ThreadReduceWithIndicesInput
<
decltype
(
in_thread_val_buf
),
decltype
(
in_thread_idx_buf
),
opReduce
,
nanPropaOpt
>
;
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
compType
,
1
,
true
>
accuValue_buf
;
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
int
,
1
,
true
>
accuIndex_buf
;
accuValue_buf
(
I0
)
=
zeroVal
;
accuIndex_buf
(
I0
)
=
0
;
const
auto
toReduceLength
=
src2dDesc
.
GetLength
(
Number
<
1
>
{});
using
ThreadBufferLengths
=
Sequence
<
1
,
GredThreadBufferLength
>
;
constexpr
auto
ThreadBufferDesc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
Number
<
GredThreadBufferLength
>
{}));
index_t
thread_global_1d_id
=
get_block_1d_id
()
*
BlockSize
+
get_thread_local_1d_id
();
auto
threadwise_src_val_load
=
ThreadwiseTensorSliceTransfer_v2
<
srcDataType
,
compType
,
src2dDescType
,
decltype
(
ThreadBufferDesc
),
ThreadBufferLengths
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
false
>
(
src2dDesc
,
make_multi_index
(
thread_global_1d_id
,
0
));
auto
threadwise_src_idx_load
=
ThreadwiseTensorSliceTransfer_v2
<
int
,
int
,
src2dDescType
,
decltype
(
ThreadBufferDesc
),
ThreadBufferLengths
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
false
>
(
src2dDesc
,
make_multi_index
(
thread_global_1d_id
,
0
));
constexpr
auto
in_thread_copy_step
=
make_multi_index
(
0
,
GredThreadBufferLength
);
for
(
index_t
reducedLength
=
0
;
reducedLength
<
toReduceLength
;
reducedLength
+=
GredThreadBufferLength
)
{
threadwise_src_val_load
.
Run
(
src2dDesc
,
src_global_val_buf
,
ThreadBufferDesc
,
make_tuple
(
I0
,
I0
),
in_thread_val_buf
);
threadwise_src_idx_load
.
Run
(
src2dDesc
,
src_global_idx_buf
,
ThreadBufferDesc
,
make_tuple
(
I0
,
I0
),
in_thread_idx_buf
);
// do the reduction on the Thread Buffer
threadwise_reduce
::
Reduce
(
in_thread_val_buf
,
in_thread_idx_buf
,
accuValue_buf
(
I0
),
accuIndex_buf
(
I0
));
threadwise_src_val_load
.
MoveSrcSliceWindow
(
src2dDesc
,
in_thread_copy_step
);
threadwise_src_idx_load
.
MoveSrcSliceWindow
(
src2dDesc
,
in_thread_copy_step
);
}
constexpr
auto
ReducedDataDesc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{}));
if
(
!
float_equal_one
{}(
alpha
))
accuValue_buf
(
I0
)
*=
type_convert
<
compType
>
{}(
alpha
);
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
dstValue_buf
;
dstValue_buf
(
I0
)
=
type_convert
<
dstDataType
>
{}(
accuValue_buf
[
I0
]);
if
(
!
float_equal_zero
{}(
beta
))
{
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
dstDataType
,
dstDataType
,
dst1dDescType
,
decltype
(
ReducedDataDesc
),
Sequence
<
1
>
,
Sequence
<
0
>
,
0
,
1
,
1
,
false
>
(
dst1dDesc
,
make_multi_index
(
thread_global_1d_id
));
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
dstDataType
,
1
,
true
>
priorDstValue_buf
;
threadwise_dst_load
.
Run
(
dst1dDesc
,
dst_global_val_buf
,
ReducedDataDesc
,
make_tuple
(
I0
),
priorDstValue_buf
);
dstValue_buf
(
I0
)
+=
priorDstValue_buf
[
I0
]
*
beta
;
}
auto
threadwise_dst_val_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
dstDataType
,
dstDataType
,
decltype
(
ReducedDataDesc
),
dst1dDescType
,
Sequence
<
1
>
,
Sequence
<
0
>
,
0
,
1
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
false
>
(
dst1dDesc
,
make_multi_index
(
thread_global_1d_id
));
auto
threadwise_dst_idx_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
int
,
int
,
decltype
(
ReducedDataDesc
),
dst1dDescType
,
Sequence
<
1
>
,
Sequence
<
0
>
,
0
,
1
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
false
>
(
dst1dDesc
,
make_multi_index
(
thread_global_1d_id
));
threadwise_dst_val_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
dstValue_buf
,
dst1dDesc
,
dst_global_val_buf
);
threadwise_dst_idx_store
.
Run
(
ReducedDataDesc
,
make_tuple
(
I0
),
accuIndex_buf
,
dst1dDesc
,
dst_global_idx_buf
);
};
};
}
// namespace ck
#endif
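The Run<2> path above walks the reduce dimension in chunks of GredThreadBufferLength and carries a running (value, index) pair per thread, offsetting the in-chunk index by indexStart on every iteration. The host-side sketch below mirrors only that accumulation pattern; the names (chunked_argmax, row, chunkLen) are hypothetical and it is an illustration, not part of the kernel.

#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Minimal model of Reduce2-style accumulation: reduce a long row in fixed-size
// chunks, keeping the running maximum and the global index where it occurred.
std::pair<float, int> chunked_argmax(const std::vector<float>& row, int chunkLen)
{
    float accuValue = std::numeric_limits<float>::lowest(); // reduction "zero" for max
    int accuIndex   = 0;
    int indexStart  = 0;

    for(std::size_t pos = 0; pos < row.size(); pos += chunkLen)
    {
        // per-chunk scan, analogous to the per-thread VGPR buffer pass
        for(int i = 0; i < chunkLen && pos + i < row.size(); ++i)
        {
            const float v = row[pos + i];
            if(v > accuValue)
            {
                accuValue = v;
                accuIndex = indexStart + i; // global index, as Reduce2 does via indexStart
            }
        }
        indexStart += chunkLen;
    }
    return {accuValue, accuIndex};
}

Calling chunked_argmax(row, 8) on any vector returns the same (max, argmax) regardless of the chunk size, which is the property the chunked kernel relies on.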
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp
0 → 100644
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_DIRECT_WARPWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_warpwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {

template <index_t BlockSize,
          typename srcDataType,
          typename dstDataType,
          typename compType,
          typename src2dDescType,
          typename dst1dDescType,
          ReduceTensorOp_t op,
          NanPropagation_t nanPropaOpt,
          ReduceTensorIndices_t reduceIndicesOpt,
          bool isFirstCall,
          bool isLastCall,
          index_t GredAccessesPerThreadInWarp>
struct GridwiseReduction_xy_to_x_direct_warpwise
{
    using opReduce = typename reduce_binary_operator<compType, op>::opType;
    using preUnaryOpType =
        typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::preUnaryOp;
    using posUnaryOpType =
        typename reduce_unary_operator<compType, op, isFirstCall, isLastCall>::posUnaryOp;

    static constexpr auto I0 = Number<0>{};

    template <int RunId>
    __device__ static void Run(const src2dDescType& src2dDesc,
                               const dst1dDescType& dst1dDesc,
                               int origReduceLen,
                               srcDataType alpha,
                               const srcDataType* const __restrict__ p_src_global,
                               dstDataType beta,
                               dstDataType* const __restrict__ p_dst_global,
                               const int* const __restrict__ ws_indices_global,
                               int* const __restrict__ indices_global);

    template <>
    __device__ static void Run<1>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)ws_indices_global;
        (void)indices_global;

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
            in_thread_buf;

        using warpwise_reduce =
            WarpReduce<decltype(in_thread_buf), BlockSize, opReduce, nanPropaOpt>;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;

        accuValue_buf(I0) = zeroVal;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);
        const posUnaryOpType posUnaryOp(divider);

        using ThreadBufferLengths       = Sequence<1, GredAccessesPerThreadInWarp>;
        constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
        index_t warp_global_1d_id   = thread_global_1d_id / warpSize;
        index_t thread_inwarp_id    = thread_global_1d_id % warpSize;

        auto threadwise_src_load =
            ThreadwiseTensorSliceTransfer_v2<srcDataType, compType, src2dDescType,
                                             decltype(ThreadBufferDesc), ThreadBufferLengths,
                                             Sequence<0, 1>, 1, 1, 1, false>(
                src2dDesc,
                make_multi_index(warp_global_1d_id,
                                 thread_inwarp_id * GredAccessesPerThreadInWarp));

        constexpr auto in_thread_copy_step =
            make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);

        for(index_t reducedLength = 0; reducedLength < toReduceLength;
            reducedLength += warpSize * GredAccessesPerThreadInWarp)
        {
            threadwise_src_load.Run(
                src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);

            // do element-wise pre-reduction operation
            warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);

            // do the warp-wise reduction on data of all thread buffers
            warpwise_reduce::Reduce(in_thread_buf, accuValue_buf(I0));

            threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
        }

        accuValue_buf(I0) = posUnaryOp(accuValue_buf[I0]);

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        // The first thread in the warp stores the reduced result to the global location
        // representing the Warp
        if(thread_inwarp_id == 0)
        {
            if(!float_equal_one{}(alpha))
                accuValue_buf(I0) *= type_convert<compType>{}(alpha);

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

            dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<dstDataType, dstDataType, dst1dDescType,
                                                     decltype(ReducedDataDesc), Sequence<1>,
                                                     Sequence<0>, 0, 1, 1, true>(
                        dst1dDesc, make_multi_index(warp_global_1d_id));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

                threadwise_dst_load.Run(
                    dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

                dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
            }

            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<dstDataType, dstDataType,
                                                   decltype(ReducedDataDesc), dst1dDescType,
                                                   Sequence<1>, Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    dst1dDesc, make_multi_index(warp_global_1d_id));

            threadwise_dst_store.Run(
                ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
        }
    };

    template <>
    __device__ static void Run<2>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)ws_indices_global;

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());
        auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            indices_global, dst1dDesc.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
            in_thread_buf;

        using warpwise_reduce =
            WarpReduce<decltype(in_thread_buf), BlockSize, opReduce, nanPropaOpt>;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        using ThreadBufferLengths       = Sequence<1, GredAccessesPerThreadInWarp>;
        constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
        index_t warp_global_1d_id   = thread_global_1d_id / warpSize;
        index_t thread_inwarp_id    = thread_global_1d_id % warpSize;

        auto threadwise_src_load =
            ThreadwiseTensorSliceTransfer_v2<srcDataType, compType, src2dDescType,
                                             decltype(ThreadBufferDesc), ThreadBufferLengths,
                                             Sequence<0, 1>, 1, 1, 1, false>(
                src2dDesc,
                make_multi_index(warp_global_1d_id,
                                 thread_inwarp_id * GredAccessesPerThreadInWarp));

        constexpr auto in_thread_copy_step =
            make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);

        index_t indexOffset = 0;
        for(index_t reducedLength = 0; reducedLength < toReduceLength;
            reducedLength += warpSize * GredAccessesPerThreadInWarp)
        {
            threadwise_src_load.Run(
                src2dDesc, src_global_buf, ThreadBufferDesc, make_tuple(I0, I0), in_thread_buf);

            // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
            // done here
            warpwise_reduce::operate_on_elements(preUnaryOp, in_thread_buf);

            // do the warp-wise reduction on data of all thread buffers
            warpwise_reduce::Reduce2(
                in_thread_buf, accuValue_buf(I0), accuIndex_buf(I0), indexOffset);

            indexOffset += warpSize * GredAccessesPerThreadInWarp;

            threadwise_src_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        // The first thread in the warp stores the reduced result to the global location
        // representing the Warp
        if(thread_inwarp_id == 0)
        {
            if(!float_equal_one{}(alpha))
                accuValue_buf(I0) *= type_convert<compType>{}(alpha);

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

            dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<dstDataType, dstDataType, dst1dDescType,
                                                     decltype(ReducedDataDesc), Sequence<1>,
                                                     Sequence<0>, 0, 1, 1, true>(
                        dst1dDesc, make_multi_index(warp_global_1d_id));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

                threadwise_dst_load.Run(dst1dDesc,
                                        dst_global_val_buf,
                                        ReducedDataDesc,
                                        make_tuple(I0),
                                        priorDstValue_buf);

                dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
            }

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<dstDataType, dstDataType,
                                                   decltype(ReducedDataDesc), dst1dDescType,
                                                   Sequence<1>, Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    dst1dDesc, make_multi_index(warp_global_1d_id));

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<int, int, decltype(ReducedDataDesc),
                                                   dst1dDescType, Sequence<1>, Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    dst1dDesc, make_multi_index(warp_global_1d_id));

            threadwise_dst_val_store.Run(
                ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
            threadwise_dst_idx_store.Run(
                ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
        }
    };

    template <>
    __device__ static void Run<3>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ ws_values_global,
                                  dstDataType beta,
                                  dstDataType* const __restrict__ p_dst_global,
                                  const int* const __restrict__ ws_indices_global,
                                  int* const __restrict__ indices_global)
    {
        (void)origReduceLen;

        const auto zeroVal = opReduce::GetReductionZeroVal();

        const auto src_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_values_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_indices_global, src2dDesc.GetElementSpaceSize());
        auto dst_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_dst_global, dst1dDesc.GetElementSpaceSize());
        auto dst_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            indices_global, dst1dDesc.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, GredAccessesPerThreadInWarp, true>
            in_thread_val_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, GredAccessesPerThreadInWarp, true>
            in_thread_idx_buf;

        using warpwise_reduce = WarpReduceWithIndicesInput<decltype(in_thread_val_buf),
                                                           decltype(in_thread_idx_buf),
                                                           BlockSize,
                                                           opReduce,
                                                           nanPropaOpt>;

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});

        using ThreadBufferLengths       = Sequence<1, GredAccessesPerThreadInWarp>;
        constexpr auto ThreadBufferDesc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<GredAccessesPerThreadInWarp>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
        index_t warp_global_1d_id   = thread_global_1d_id / warpSize;
        index_t thread_inwarp_id    = thread_global_1d_id % warpSize;

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<srcDataType, compType, src2dDescType,
                                             decltype(ThreadBufferDesc), ThreadBufferLengths,
                                             Sequence<0, 1>, 1, 1, 1, false>(
                src2dDesc,
                make_multi_index(warp_global_1d_id,
                                 thread_inwarp_id * GredAccessesPerThreadInWarp));

        auto threadwise_src_idx_load =
            ThreadwiseTensorSliceTransfer_v2<int, int, src2dDescType, decltype(ThreadBufferDesc),
                                             ThreadBufferLengths, Sequence<0, 1>, 1, 1, 1, false>(
                src2dDesc,
                make_multi_index(warp_global_1d_id,
                                 thread_inwarp_id * GredAccessesPerThreadInWarp));

        constexpr auto in_thread_copy_step =
            make_multi_index(0, warpSize * GredAccessesPerThreadInWarp);

        for(index_t reducedLength = 0; reducedLength < toReduceLength;
            reducedLength += warpSize * GredAccessesPerThreadInWarp)
        {
            threadwise_src_val_load.Run(src2dDesc,
                                        src_global_val_buf,
                                        ThreadBufferDesc,
                                        make_tuple(I0, I0),
                                        in_thread_val_buf);
            threadwise_src_idx_load.Run(src2dDesc,
                                        src_global_idx_buf,
                                        ThreadBufferDesc,
                                        make_tuple(I0, I0),
                                        in_thread_idx_buf);

            // do the warp-wise reduction on data of all thread buffers
            warpwise_reduce::Reduce(
                in_thread_val_buf, in_thread_idx_buf, accuValue_buf(I0), accuIndex_buf(I0));

            threadwise_src_val_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
            threadwise_src_idx_load.MoveSrcSliceWindow(src2dDesc, in_thread_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        // The first thread in the warp stores the reduced result to the global location
        // representing the Warp
        if(thread_inwarp_id == 0)
        {
            if(!float_equal_one{}(alpha))
                accuValue_buf(I0) *= type_convert<compType>{}(alpha);

            StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

            dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

            if(!float_equal_zero{}(beta))
            {
                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<dstDataType, dstDataType, dst1dDescType,
                                                     decltype(ReducedDataDesc), Sequence<1>,
                                                     Sequence<0>, 0, 1, 1, true>(
                        dst1dDesc, make_multi_index(warp_global_1d_id));

                StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> priorDstValue_buf;

                threadwise_dst_load.Run(dst1dDesc,
                                        dst_global_val_buf,
                                        ReducedDataDesc,
                                        make_tuple(I0),
                                        priorDstValue_buf);

                dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
            }

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<dstDataType, dstDataType,
                                                   decltype(ReducedDataDesc), dst1dDescType,
                                                   Sequence<1>, Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    dst1dDesc, make_multi_index(warp_global_1d_id));

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<int, int, decltype(ReducedDataDesc),
                                                   dst1dDescType, Sequence<1>, Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    dst1dDesc, make_multi_index(warp_global_1d_id));

            threadwise_dst_val_store.Run(
                ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
            threadwise_dst_idx_store.Run(
                ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
        }
    };
};
} // namespace ck
#endif
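The warp-wise kernels above hand their per-thread partial results to WarpReduce, whose shuffle path combines lanes with __shfl_down at strides warpSize/2, warpSize/4, ..., 1, so lane 0 ends up holding the warp's result. The sketch below reproduces only that halving-stride combine on the host, with a vector standing in for the lanes of one warp; the names (warp_tree_reduce, lane_vals) are illustrative, not part of the library.

#include <vector>

// Model of a halving-stride warp reduction over a power-of-two lane count:
// after log2(warpSize) rounds, element 0 holds the sum of all lanes.
float warp_tree_reduce(std::vector<float> lane_vals)
{
    const int warpSize = static_cast<int>(lane_vals.size()); // e.g. 64 lanes, assumed power of two

    for(int stride = warpSize / 2; stride > 0; stride /= 2)
    {
        for(int lane = 0; lane < stride; ++lane)
        {
            // __shfl_down(x, stride) would give lane i the value held by lane i + stride
            lane_vals[lane] += lane_vals[lane + stride];
        }
    }
    return lane_vals[0];
}

With 64 lanes this takes 6 rounds, which is why the kernel only needs the first lane of each warp to write the final value back to global memory.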
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp
0 → 100644
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_blockwise.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
namespace ck {

template <index_t BlockSize,
          typename srcDataType,
          typename dstDataType, // not used together with the beta input
          typename compType,
          typename src2dDescType,
          typename dst1dDescType,
          ReduceTensorOp_t op,
          NanPropagation_t nanPropaOpt,
          ReduceTensorIndices_t reduceIndicesOpt,
          index_t GredAccessesPerThreadInBlock>
struct GridwiseReduction_xy_to_x_multiblock
{
    using opReduce       = typename reduce_binary_operator<compType, op>::opType;
    using preUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::preUnaryOp;
    using posUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::posUnaryOp;

    static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));
    using blockwise_reduce =
        BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;

    static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();

    static constexpr auto I0 = Number<0>{};

    template <int RunId>
    __device__ static void Run(const src2dDescType& src2dDesc,
                               const dst1dDescType& dst1dDesc,
                               int origReduceLen,
                               int BlkGroupSize,
                               srcDataType alpha,
                               const srcDataType* const __restrict__ p_src_global,
                               dstDataType beta,
                               srcDataType* const __restrict__ ws_values_global,
                               int* const __restrict__ ws_indices_global);

    template <>
    __device__ static void Run<1>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  int BlkGroupSize,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  srcDataType* const __restrict__ ws_values_global,
                                  int* const __restrict__ ws_indices_global)
    {
        (void)ws_indices_global;
        (void)alpha; // unused
        (void)beta;  // unused

        const auto zeroVal = opReduce::GetReductionZeroVal();

        // LDS
        __shared__ compType p_in_block_buffer[BlockBufferSize];

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);

        auto in_block_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;

        accuValue_buf(I0) = zeroVal;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / BlkGroupSize;
        const index_t block_local_id  = block_global_id % BlkGroupSize;

        const index_t reduceSizePerBlock =
            (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
             BlockBufferSize) *
            BlockBufferSize;

        constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_load =
            BlockwiseTensorSliceTransfer_v4<BlockSize, InMemoryDataOperationEnum_t::Set,
                                            Sequence<1, BlockBufferSize>, ThreadSliceLengths,
                                            ThreadClusterLengths, Sequence<0, 1>, srcDataType,
                                            compType, src2dDescType, decltype(in_block_desc),
                                            Sequence<0, 1>, Sequence<0, 1>, 1, 1, 1, 1, 1, 1,
                                            false, true>(
                src2dDesc,
                make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
                in_block_desc,
                make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            blockwise_src_load.RunRead(src2dDesc, src_global_buf);
            blockwise_src_load.RunWrite(in_block_desc, in_block_buf);

            __syncthreads();

            // do element-wise pre-reduction operation
            blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;

            blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));

            blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        const auto workspace_desc = make_naive_tensor_descriptor_packed(
            make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            auto threadwise_workspace_store =
                ThreadwiseTensorSliceTransfer_v1r3<compType, srcDataType,
                                                   decltype(ReducedDataDesc),
                                                   decltype(workspace_desc), Sequence<1>,
                                                   Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    workspace_desc, make_multi_index(block_global_id));

            threadwise_workspace_store.Run(ReducedDataDesc,
                                           make_tuple(I0),
                                           accuValue_buf,
                                           workspace_desc,
                                           workspace_global_buf);
        }
    };

    template <>
    __device__ static void Run<2>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  int BlkGroupSize,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  srcDataType* const __restrict__ ws_values_global,
                                  int* const __restrict__ ws_indices_global)
    {
        (void)alpha; // unused
        (void)beta;  // unused

        const auto zeroVal = opReduce::GetReductionZeroVal();

        // LDS
        __shared__ compType p_in_block_values_buffer[BlockBufferSize];
        __shared__ int p_in_block_indices_buffer[BlockBufferSize];

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
        auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_indices_global, dst1dDesc.GetLength(I0) * BlkGroupSize);

        auto in_block_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_in_block_values_buffer, BlockBufferSize);
        auto in_block_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_in_block_indices_buffer, BlockBufferSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / BlkGroupSize;
        const index_t block_local_id  = block_global_id % BlkGroupSize;

        const index_t reduceSizePerBlock =
            (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
             BlockBufferSize) *
            BlockBufferSize;

        constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_load =
            BlockwiseTensorSliceTransfer_v4<BlockSize, InMemoryDataOperationEnum_t::Set,
                                            Sequence<1, BlockBufferSize>, ThreadSliceLengths,
                                            ThreadClusterLengths, Sequence<0, 1>, srcDataType,
                                            compType, src2dDescType, decltype(in_block_desc),
                                            Sequence<0, 1>, Sequence<0, 1>, 1, 1, 1, 1, 1, 1,
                                            false, true>(
                src2dDesc,
                make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
                in_block_desc,
                make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;

        int indexOffset = block_local_id * reduceSizePerBlock;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);

            blockwise_src_load.RunRead(src2dDesc, src_global_buf);
            blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);

            __syncthreads();

            // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
            // done here
            blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;

            blockwise_reduce::Reduce2(in_block_val_buf,
                                      in_block_idx_buf,
                                      BlocksInOneOp,
                                      accuValue_buf(I0),
                                      accuIndex_buf(I0));

            indexOffset += BlockBufferSize;

            blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        const auto workspace_desc = make_naive_tensor_descriptor_packed(
            make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            auto threadwise_workspace_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<compType, srcDataType,
                                                   decltype(ReducedDataDesc),
                                                   decltype(workspace_desc), Sequence<1>,
                                                   Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    workspace_desc, make_multi_index(block_global_id));

            auto threadwise_workspace_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<int, int, decltype(ReducedDataDesc),
                                                   decltype(workspace_desc), Sequence<1>,
                                                   Sequence<0>, 0, 1,
                                                   InMemoryDataOperationEnum_t::Set, 1, true>(
                    workspace_desc, make_multi_index(block_global_id));

            threadwise_workspace_val_store.Run(ReducedDataDesc,
                                               make_tuple(I0),
                                               accuValue_buf,
                                               workspace_desc,
                                               workspace_global_val_buf);
            threadwise_workspace_idx_store.Run(ReducedDataDesc,
                                               make_tuple(I0),
                                               accuIndex_buf,
                                               workspace_desc,
                                               workspace_global_idx_buf);
        }
    };
};
} // namespace ck
#endif
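In the multi-block kernel each destination row is shared by BlkGroupSize blocks: reduceSizePerBlock rounds the per-block share of the reduce dimension up to a whole number of block buffers, and every block writes exactly one partial result into a workspace of dst1dDesc.GetLength(I0) * BlkGroupSize elements that a later pass reduces again. A quick host-side check of that arithmetic, using made-up sizes, is shown below; it only illustrates the formula, nothing here is part of the kernel.

#include <cstdio>

int main()
{
    const int toReduceLength  = 10000; // length of the reduce dimension (example value)
    const int BlkGroupSize    = 4;     // blocks cooperating on one output element (example)
    const int BlockBufferSize = 1024;  // BlockSize * GredAccessesPerThreadInBlock (example)

    // same double ceil-division as the kernel: per-block share, rounded up to
    // a whole number of block buffers
    const int reduceSizePerBlock =
        (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
         BlockBufferSize) *
        BlockBufferSize;

    // with these numbers: ceil(10000 / 4) = 2500, rounded up to 3 * 1024 = 3072
    std::printf("each block covers %d elements of the reduce dimension\n", reduceSizePerBlock);
    return 0;
}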
composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp
0 → 100644
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {

template <typename buffer2dDescType,
          bool blockIsOneRow,
          typename opReduce,
          NanPropagation_t nanPropaOpt>
struct BlockwiseReduction_2d_block_buffer
{
    using compType = typename opReduce::dataType;

    static constexpr auto buffer2dDesc = buffer2dDescType{};

    static constexpr index_t BlockSize =
        blockIsOneRow ? buffer2dDesc.GetLength(Number<1>{}) : buffer2dDesc.GetLength(Number<0>{});
    static constexpr index_t NumBlocks =
        blockIsOneRow ? buffer2dDesc.GetLength(Number<0>{}) : buffer2dDesc.GetLength(Number<1>{});

    using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    // This interface does not accumulate on indices
    template <typename BufferType>
    __device__ static void
    Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
    {
        const index_t thread_local_id = get_thread_local_1d_id();
        compType lAccuData            = opReduce::GetReductionZeroVal();

        index_t offset;
        for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
        {
            offset = blockIsOneRow
                         ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id))
                         : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));

            compType opData = type_convert<compType>{}(block_buffer[offset]);
            binop::calculate(lAccuData, opData);
        }

        offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
                               : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));

        block_buffer(offset) = lAccuData;

        __syncthreads();

        for(index_t indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2)
        {
            if(thread_local_id < indOffset)
            {
                index_t offset1 =
                    blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
                                  : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));

                index_t offset2 =
                    blockIsOneRow
                        ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset))
                        : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));

                compType opData1 = type_convert<compType>{}(block_buffer[offset1]);
                compType opData2 = type_convert<compType>{}(block_buffer[offset2]);

                binop::calculate(opData1, opData2);

                block_buffer(offset1) = type_convert<compType>{}(opData1);
            }

            __syncthreads();
        }

        if(thread_local_id == 0)
        {
            compType tmpVal = type_convert<compType>{}(block_buffer[0]);

            binop::calculate(accuData, tmpVal);
        }
    };

    // This interface accumulates on both data values and indices
    template <typename BufferType, typename IdxBufferType>
    __device__ static void Reduce2(BufferType& block_buffer,
                                   IdxBufferType& block_indices_buffer,
                                   index_t toReduceBlocks,
                                   compType& accuData,
                                   int& accuIndex)
    {
        const index_t thread_local_id = get_thread_local_1d_id();
        compType lAccuData            = opReduce::GetReductionZeroVal();
        int lAccuIndex                = 0;

        if constexpr(blockIsOneRow)
        {
            for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
            {
                for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
                {
                    if(thread_local_id % (indOffset * 2) == 0)
                    {
                        index_t offset1 =
                            buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id));
                        index_t offset2 = buffer2dDesc.CalculateOffset(
                            make_tuple(otherDimInd, thread_local_id + indOffset));

                        compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
                        compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
                        int currIndex1    = block_indices_buffer[offset1];
                        int currIndex2    = block_indices_buffer[offset2];

                        binop::calculate(currVal1, currVal2, currIndex1, currIndex2);

                        block_buffer(offset1)         = type_convert<compType>{}(currVal1);
                        block_indices_buffer(offset1) = currIndex1;
                    }
                    __syncthreads();
                }
            }

            if(thread_local_id == 0)
            {
                for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
                {
                    index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0));

                    compType tmpVal = type_convert<compType>{}(block_buffer[offset]);
                    int tmpIndex    = block_indices_buffer[offset];

                    binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
                }

                binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
            }
        }
        else
        {
            index_t offset;

            for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
            {
                offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));

                compType currVal = type_convert<compType>{}(block_buffer[offset]);
                int currIndex    = block_indices_buffer[offset];

                binop::calculate(lAccuData, currVal, lAccuIndex, currIndex);
            }

            offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));

            block_buffer(offset)         = lAccuData;
            block_indices_buffer(offset) = lAccuIndex;

            __syncthreads();

            for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
            {
                if(thread_local_id % (indOffset * 2) == 0)
                {
                    index_t offset1 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
                    index_t offset2 =
                        buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));

                    compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
                    compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
                    int currIndex1    = block_indices_buffer[offset1];
                    int currIndex2    = block_indices_buffer[offset2];

                    binop::calculate(currVal1, currVal2, currIndex1, currIndex2);

                    block_buffer(offset1)         = type_convert<compType>{}(currVal1);
                    block_indices_buffer(offset1) = currIndex1;
                }
                __syncthreads();
            }

            if(thread_local_id == 0)
            {
                compType tmpVal = type_convert<compType>{}(block_buffer[0]);
                int tmpIndex    = block_indices_buffer[0];

                binop::calculate(accuData, tmpVal, accuIndex, tmpIndex);
            }
        }
    };

    template <typename BufferType>
    __device__ static void set_buffer_value(BufferType& block_buffer, compType value)
    {
        index_t thread_id = get_thread_local_1d_id();

        for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
        {
            index_t offset = blockIsOneRow
                                 ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
                                 : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));

            block_buffer(offset) = value;

            __syncthreads();
        }
    };

    // Initialize the block-wise indices buffer, the index for each element in the block-wise data
    // buffer is calculated according to its position in the buffer and the global starting index
    template <typename IdxBufferType>
    __device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart)
    {
        index_t thread_id = get_thread_local_1d_id();

        for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
        {
            index_t offset = blockIsOneRow
                                 ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
                                 : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));

            block_indices_buffer(offset) = offset + indexStart;

            __syncthreads();
        }
    };

    // Execute unary operation on the block buffer elements
    template <typename unary_op_type, typename BufferType>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& block_buffer)
    {
        index_t thread_id = get_thread_local_1d_id();

        for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
        {
            index_t offset = blockIsOneRow
                                 ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
                                 : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));

            block_buffer(offset) = unary_op(block_buffer[offset]);

            __syncthreads();
        }
    };
};
}; // end of namespace ck
#endif
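Reduce2 above merges (value, index) pairs in LDS with a doubling stride: at stride s every thread whose id is a multiple of 2*s absorbs the entry s positions away, so after log2(BlockSize) rounds slot 0 holds the block's result. The host-side model below reproduces that doubling-stride argmax merge only; the names (block_tree_argmax, vals, idxs) are illustrative and not taken from the library.

#include <utility>
#include <vector>

// Doubling-stride in-place argmax over a power-of-two sized buffer, mirroring
// the thread_local_id % (indOffset * 2) == 0 pattern used by Reduce2.
std::pair<float, int> block_tree_argmax(std::vector<float> vals, std::vector<int> idxs)
{
    const int blockSize = static_cast<int>(vals.size()); // assumed to be a power of two

    for(int stride = 1; stride < blockSize; stride *= 2)
    {
        for(int tid = 0; tid < blockSize; tid += 2 * stride)
        {
            // stand-in for binop::calculate on a (value, index) pair with a max op
            if(vals[tid + stride] > vals[tid])
            {
                vals[tid] = vals[tid + stride];
                idxs[tid] = idxs[tid + stride];
            }
        }
    }
    return {vals[0], idxs[0]}; // slot 0 holds the surviving value and its index
}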
composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp
0 → 100644
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {

template <typename BufferType, typename opReduce, NanPropagation_t nanPropaOpt>
struct ThreadReduce
{
    using compType = typename opReduce::dataType;

    static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!");

    static_assert(std::is_same<typename BufferType::type, compType>::value,
                  "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!");

    static constexpr index_t ThreadBufferLen = BufferType::Size();

    using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    // This interface does not accumulate on indices
    __device__ static void Reduce(const BufferType& thread_buffer, compType& accuData)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { binop::calculate(accuData, thread_buffer[I]); });
    };

    // This interface accumulates on both data values and indices and
    // is called by Direct_ThreadWise reduction method at first-time reduction
    __device__ static void
    Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            int currIndex = I + indexStart;
            binop::calculate(accuData, thread_buffer[I], accuIndex, currIndex);
        });
    };

    // Set the elements in the per-thread buffer to a specific value
    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
    };

    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
    };
};

template <typename BufferType,
          typename IdxBufferType,
          typename opReduce,
          NanPropagation_t nanPropaOpt>
struct ThreadReduceWithIndicesInput
{
    using compType = typename opReduce::dataType;

    static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!");
    static_assert(IdxBufferType::IsStaticBuffer(),
                  "Thread-wise reduction needs use StaticBuffer for indices!");

    static_assert(std::is_same<typename BufferType::type, compType>::value,
                  "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!");
    static_assert(std::is_same<typename IdxBufferType::type, index_t>::value,
                  "Indices type of StaticBuffer for Thread-wise reduction should be index_t!");

    static_assert(BufferType::Size() == IdxBufferType::Size(),
                  "StaticBuffers for data and indices should have the same sizes!");

    static constexpr index_t ThreadBufferLen = BufferType::Size();

    using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    // This interface accumulates on both data values and indices and
    // is called by Direct_ThreadWise reduction method at second-time reduction
    __device__ static void Reduce(const BufferType& thread_buffer,
                                  const IdxBufferType& thread_indices_buffer,
                                  compType& accuData,
                                  int& accuIndex)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            binop::calculate(accuData, thread_buffer[I], accuIndex, thread_indices_buffer[I]);
        });
    };

    // Set the elements in the per-thread buffer to a specific value
    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
    };

    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
    };
};
}; // end of namespace ck
#endif
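Both thread-wise reduction structs funnel every element through binop_with_nan_check::calculate, which updates an accumulator in place and, in the indexed overload, also carries the winning index. A standalone sketch of that calling contract for a max-style reduction follows; the behaviour and the name calculate_max are illustrative only and are not the library's definition of the binop.

#include <cmath>

// Two-argument form: fold one value into the accumulator.
inline void calculate_max(float& accu, float curr)
{
    if(std::isnan(curr) || curr > accu) // crude stand-in for NaN propagation
        accu = curr;
}

// Four-argument form: fold a value and remember where it came from, matching
// the calculate(accuData, val, accuIndex, idx) calls made by Reduce/Reduce2.
inline void calculate_max(float& accu, float curr, int& accuIndex, int currIndex)
{
    if(std::isnan(curr) || curr > accu)
    {
        accu      = curr;
        accuIndex = currIndex;
    }
}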
composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp
0 → 100644
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
#define CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace
ck
{
template
<
typename
BufferType
,
index_t
BlockSize
,
typename
opReduce
,
NanPropagation_t
nanPropaOpt
>
struct
WarpReduce
{
using
compType
=
typename
opReduce
::
dataType
;
using
binop
=
detail
::
binop_with_nan_check
<
nanPropaOpt
,
opReduce
,
compType
>
;
static_assert
(
BufferType
::
IsStaticBuffer
(),
"Per-thread buffer for WarpWise reduction should be StaticBuffer!"
);
static_assert
(
std
::
is_same
<
typename
BufferType
::
type
,
compType
>::
value
,
"Data type of per-thread StaticBuffer for WarpWise reduction should be same as "
"the compType!"
);
static
constexpr
index_t
ThreadBufferLen
=
BufferType
::
Size
();
static
constexpr
bool
have_builtin_shuffle
=
std
::
is_same
<
compType
,
float
>::
value
||
std
::
is_same
<
compType
,
double
>::
value
;
// This interface does not accumulate on indices
__device__
static
void
Reduce
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
)
{
if
constexpr
(
have_builtin_shuffle
)
ReduceImpl1
(
thread_buffer
,
accuData
);
else
ReduceImpl2
(
thread_buffer
,
accuData
);
};
// This interface implementation uses HIP built-in device shuffling functions
__device__
static
void
ReduceImpl1
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
)
{
compType
lAccuData
=
opReduce
::
GetReductionZeroVal
();
static_for
<
0
,
ThreadBufferLen
,
1
>
{}(
[
&
](
auto
I
)
{
binop
::
calculate
(
lAccuData
,
thread_buffer
[
I
]);
});
// synchronize among all threads in this warp
__all
(
1
);
for
(
index_t
stride
=
warpSize
/
2
;
stride
>
0
;
stride
/=
2
)
{
compType
tmpVal
=
__shfl_down
(
lAccuData
,
stride
,
warpSize
);
binop
::
calculate
(
lAccuData
,
tmpVal
);
__all
(
1
);
}
binop
::
calculate
(
accuData
,
lAccuData
);
};
// This interface implementation does not use HIP built-in device shuffling functions
// since for fp16, built-in shuffling functions is not provided by HIP
__device__
static
void
ReduceImpl2
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
)
{
compType
lAccuData
=
opReduce
::
GetReductionZeroVal
();
static_for
<
0
,
ThreadBufferLen
,
1
>
{}(
[
&
](
auto
I
)
{
binop
::
calculate
(
lAccuData
,
thread_buffer
[
I
]);
});
__syncthreads
();
index_t
thread_id
=
get_thread_local_1d_id
();
index_t
warpId
=
thread_id
/
warpSize
;
index_t
thread_inwarp_id
=
thread_id
%
warpSize
;
__shared__
compType
shuffle_buffer
[
BlockSize
];
compType
*
myBuffer
=
&
shuffle_buffer
[
warpId
*
warpSize
];
myBuffer
[
thread_inwarp_id
]
=
lAccuData
;
__syncthreads
();
for
(
index_t
stride
=
warpSize
/
2
;
stride
>
0
;
stride
/=
2
)
{
if
(
thread_inwarp_id
<
stride
)
{
compType
currVal1
=
myBuffer
[
thread_inwarp_id
];
compType
currVal2
=
myBuffer
[
thread_inwarp_id
+
stride
];
binop
::
calculate
(
currVal1
,
currVal2
);
myBuffer
[
thread_inwarp_id
]
=
currVal1
;
}
__syncthreads
();
}
if
(
thread_inwarp_id
==
0
)
binop
::
calculate
(
accuData
,
myBuffer
[
0
]);
};
// This interface accumulates on both data values and indices and is called by Direct_WarpWise
// reduction method at first-time reduction
__device__
static
void
Reduce2
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
,
int
&
accuIndex
,
int
indexStart
)
{
if
constexpr
(
have_builtin_shuffle
)
Reduce2Impl1
(
thread_buffer
,
accuData
,
accuIndex
,
indexStart
);
else
Reduce2Impl2
(
thread_buffer
,
accuData
,
accuIndex
,
indexStart
);
};
// This interface implementation uses HIP built-in device shuffling functions
__device__
static
void
Reduce2Impl1
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
,
int
&
accuIndex
,
int
indexStart
)
{
compType
lAccuData
=
opReduce
::
GetReductionZeroVal
();
int
lAccuIndex
=
0
;
index_t
thread_inwarp_id
=
get_thread_local_1d_id
()
%
warpSize
;
static_for
<
0
,
ThreadBufferLen
,
1
>
{}([
&
](
auto
I
)
{
int
currIndex
=
thread_inwarp_id
*
ThreadBufferLen
+
I
+
indexStart
;
binop
::
calculate
(
lAccuData
,
thread_buffer
[
I
],
lAccuIndex
,
currIndex
);
});
// synchronize among all threads in this warp
__all
(
1
);
for
(
index_t
stride
=
1
;
stride
<
warpSize
;
stride
*=
2
)
{
compType
tmpVal
=
__shfl_down
(
lAccuData
,
stride
,
warpSize
);
int
tmpIndex
=
__shfl_down
(
lAccuIndex
,
stride
,
warpSize
);
binop
::
calculate
(
lAccuData
,
tmpVal
,
lAccuIndex
,
tmpIndex
);
__all
(
1
);
}
if
(
thread_inwarp_id
==
0
)
binop
::
calculate
(
accuData
,
lAccuData
,
accuIndex
,
lAccuIndex
);
};
// This interface implementation does not use HIP built-in device shuffling functions, since for
// fp16 built-in shuffling functions are not provided by HIP
__device__ static void
Reduce2Impl2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
{
    compType lAccuData = opReduce::GetReductionZeroVal();
    int lAccuIndex     = 0;

    index_t thread_id        = get_thread_local_1d_id();
    index_t warpId           = thread_id / warpSize;
    index_t thread_inwarp_id = thread_id % warpSize;

    static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
        int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
        binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
    });

    __shared__ compType shuffle_data_buffer[BlockSize];
    __shared__ int shuffle_indices_buffer[BlockSize];

    compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
    int* myIndicesBuffer   = &shuffle_indices_buffer[warpId * warpSize];

    myDataBuffer[thread_inwarp_id]    = lAccuData;
    myIndicesBuffer[thread_inwarp_id] = lAccuIndex;

    __syncthreads();

    for(index_t stride = 1; stride < warpSize; stride *= 2)
    {
        compType currVal1 = myDataBuffer[thread_inwarp_id];
        compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
        int currIndex1    = myIndicesBuffer[thread_inwarp_id];
        int currIndex2    = myIndicesBuffer[thread_inwarp_id + stride];

        binop::calculate(currVal1, currVal2, currIndex1, currIndex2);

        myDataBuffer[thread_inwarp_id]    = currVal1;
        myIndicesBuffer[thread_inwarp_id] = currIndex1;

        __syncthreads();
    }

    if(thread_inwarp_id == 0)
        binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
};
// cppcheck-suppress constParameter
__device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
{
    static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });

    __all(1);
};
// Execute unary operation on the per-thread buffer elements
template <typename unary_op_type>
__device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
{
    static_for<0, ThreadBufferLen, 1>{}(
        [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });

    __all(1);
};
};
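For orientation, the shuffle-based path above (ReduceImpl1) is the usual warp-level tree reduction. Below is a minimal standalone HIP sketch of the same pattern, written for a plain float sum and independent of the CK buffer/operator machinery; warp_sum is a hypothetical name, not part of composable_kernel.

#include <hip/hip_runtime.h>

// Sketch only: each halving step pulls the partial result from the lane `stride` positions
// higher, so after log2(warpSize) steps lane 0 holds the warp-wide sum.
__device__ float warp_sum(float v)
{
    for(int stride = warpSize / 2; stride > 0; stride /= 2)
        v += __shfl_down(v, stride, warpSize);
    return v; // meaningful in lane 0
}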
template <typename BufferType,
          typename IdxBufferType,
          index_t BlockSize,
          typename opReduce,
          NanPropagation_t nanPropaOpt>
struct WarpReduceWithIndicesInput
{
    using compType = typename opReduce::dataType;
    using binop    = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    static_assert(BufferType::IsStaticBuffer(),
                  "Per-thread buffer for WarpWise reduction should be StaticBuffer!");

    static_assert(IdxBufferType::IsStaticBuffer(),
                  "Per-thread buffer for WarpWise reduction should be StaticBuffer for indices!");

    static_assert(std::is_same<typename BufferType::type, compType>::value,
                  "Data type of per-thread StaticBuffer for WarpWise reduction should be the same "
                  "as the compType!");

    static_assert(std::is_same<typename IdxBufferType::type, index_t>::value,
                  "Indices type of per-thread StaticBuffer for WarpWise reduction should be index_t!");

    static_assert(BufferType::Size() == IdxBufferType::Size(),
                  "StaticBuffers for data and indices should have the same sizes!");

    static constexpr index_t ThreadBufferLen = BufferType::Size();

    static constexpr bool have_builtin_shuffle =
        std::is_same<compType, float>::value || std::is_same<compType, double>::value;
    // This interface accumulates on both data values and indices and is called by the Direct_WarpWise
    // reduction method at second-time reduction
    __device__ static void Reduce(const BufferType& thread_buffer,
                                  const IdxBufferType& thread_indices_buffer,
                                  compType& accuData,
                                  int& accuIndex)
    {
        if constexpr(have_builtin_shuffle)
            ReduceImpl1(thread_buffer, thread_indices_buffer, accuData, accuIndex);
        else
            ReduceImpl2(thread_buffer, thread_indices_buffer, accuData, accuIndex);
    };
    // This interface implementation uses HIP built-in device shuffling functions
    __device__ static void ReduceImpl1(const BufferType& thread_buffer,
                                       const IdxBufferType& thread_indices_buffer,
                                       compType& accuData,
                                       int& accuIndex)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();
        int lAccuIndex     = 0;

        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
        });

        // synchronize among all threads in this warp
        __all(1);

        for(index_t stride = 1; stride < warpSize; stride *= 2)
        {
            compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
            int tmpIndex    = __shfl_down(lAccuIndex, stride, warpSize);

            binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
            __all(1);
        }

        binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
    };
    // This interface implementation does not use HIP built-in device shuffling functions,
    // since for fp16 built-in shuffling functions are not provided by HIP
    __device__ static void ReduceImpl2(const BufferType& thread_buffer,
                                       const IdxBufferType& thread_indices_buffer,
                                       compType& accuData,
                                       int& accuIndex)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();
        int lAccuIndex     = 0;

        index_t thread_id        = get_thread_local_1d_id();
        index_t warpId           = thread_id / warpSize;
        index_t thread_inwarp_id = thread_id % warpSize;

        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
        });

        __shared__ compType shuffle_data_buffer[BlockSize];
        __shared__ int shuffle_indices_buffer[BlockSize];

        compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
        int* myIndicesBuffer   = &shuffle_indices_buffer[warpId * warpSize];

        myDataBuffer[thread_inwarp_id]    = lAccuData;
        myIndicesBuffer[thread_inwarp_id] = lAccuIndex;

        __syncthreads();

        for(index_t stride = 1; stride < warpSize; stride *= 2)
        {
            compType currVal1 = myDataBuffer[thread_inwarp_id];
            compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
            int currIndex1    = myIndicesBuffer[thread_inwarp_id];
            int currIndex2    = myIndicesBuffer[thread_inwarp_id + stride];

            binop::calculate(currVal1, currVal2, currIndex1, currIndex2);

            myDataBuffer[thread_inwarp_id]    = currVal1;
            myIndicesBuffer[thread_inwarp_id] = currIndex1;

            __syncthreads();
        }

        if(thread_inwarp_id == 0)
            binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
    };
    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });

        __all(1);
    };
    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });

        __all(1);
    };
};
}; // end of namespace ck
#endif
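The indexed variants above carry a (value, index) pair through the same tree so that the position of the winning element survives the reduction. Below is a standalone sketch of that idea as a warp-wide argmax, again plain HIP with hypothetical names and a hard-coded max comparison instead of the opReduce/binop machinery.

#include <hip/hip_runtime.h>

// Sketch only: the data value and its source index are shuffled together at every step,
// mirroring the paired shuffles in Reduce2Impl1/ReduceImpl1 above.
__device__ void warp_argmax(float v, int idx, float& out_val, int& out_idx)
{
    for(int stride = 1; stride < warpSize; stride *= 2)
    {
        float other_v   = __shfl_down(v, stride, warpSize);
        int   other_idx = __shfl_down(idx, stride, warpSize);
        if(other_v > v)
        {
            v   = other_v;
            idx = other_idx;
        }
    }
    out_val = v; // meaningful in lane 0
    out_idx = idx;
}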
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...
...
@@ -397,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v2
"wrong! SrcDesc need to known at compile-time"
);
}
__device__
void
Set
Dst
SliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
__device__
void
Set
Src
SliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
{
src_coord_
=
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
...
...
@@ -713,9 +713,6 @@ struct ThreadwiseTensorSliceTransfer_v3
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
    {
-       // TODO: fix this
-       static_assert(is_same<SrcData, DstData>::value,
-                     "wrong! current implementation assume SrcData and DstData are same type");
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
...
@@ -985,7 +982,8 @@ struct ThreadwiseTensorSliceTransfer_v3
                constexpr index_t buffer_offset =
                    buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);

-               dst_tmp_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
+               dst_tmp_vector.template AsType<DstData>()(i) =
+                   type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
            });

            using dst_vector_t = typename decltype(dst_tmp_vector)::type;
...
...
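The hunk above wraps the buffer read in type_convert<DstData>{} so the element is converted explicitly when SrcData and DstData differ, which is presumably why the SrcData/DstData same-type static_assert could be dropped earlier in this file. As a rough illustration of what such a conversion functor amounts to (my_type_convert is a hypothetical stand-in, not CK's actual type_convert, which may carry extra specializations):

// Illustrative stand-in only.
template <typename Dst>
struct my_type_convert
{
    template <typename Src>
    __device__ constexpr Dst operator()(const Src& v) const
    {
        return static_cast<Dst>(v); // plain conversion; real implementations may special-case types
    }
};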
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
...
...
@@ -44,15 +44,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
    static constexpr index_t k_per_blk = 1;

    static constexpr bool is_k_reduction = false;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -71,15 +66,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
    static constexpr index_t k_per_blk = 1;

    static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -98,15 +88,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
    static constexpr index_t k_per_blk = 1;

    static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -125,15 +110,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
    static constexpr index_t k_per_blk = 1;

    static constexpr bool is_k_reduction = false;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -153,15 +133,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
    static constexpr index_t k_per_blk = 1;

    static constexpr bool is_k_reduction = false;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -180,15 +155,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
    static constexpr index_t k_per_blk = 4;

    static constexpr bool is_k_reduction = false;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -207,15 +177,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
    static constexpr index_t k_per_blk = 4;

    static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -234,15 +199,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
    static constexpr index_t k_per_blk = 4;

    static constexpr bool is_k_reduction = true;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -261,15 +221,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
    static constexpr index_t k_per_blk = 4;

    static constexpr bool is_k_reduction = false;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -288,15 +243,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
    static constexpr index_t k_per_blk = 4;

    static constexpr bool is_k_reduction = false;

-   template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-             class FloatA, class FloatB, class FloatC>
+   template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
-       intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+       intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
...
...
@@ -732,7 +682,7 @@ struct XdlopsGemm
        return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
    }

-   template <index_t c_offset, class FloatA, class FloatB, class FloatC>
+   template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
...
...
@@ -740,8 +690,7 @@ struct XdlopsGemm
"base base_type must be float, half, ushort!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
});
}
...
...
@@ -819,8 +768,9 @@ struct XdlopsGemm
    static constexpr auto mfma_instr = mfma.selected_mfma;

-   static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
-   static constexpr auto KPerThread = mfma.GetKPerThread();
+   static constexpr auto KPerXdlops  = mfma.GetKPerXdlops();
+   static constexpr auto K1PerXdlops = mfma.GetKPerThread();
+   static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

    __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
    {
...
...
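For readers unfamiliar with ck::static_for, the Run() hunk above issues one mfma per k in an unrolled compile-time loop of length KPack / k_per_blk. Below is a minimal C++17 sketch of such a compile-time loop; unroll_for is a hypothetical stand-in, not CK's static_for.

#include <utility>

// Sketch only: calls f(integral_constant<int, 0>{}), f(integral_constant<int, 1>{}), ...
template <typename F, int... Is>
constexpr void unroll_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
constexpr void unroll_for(F&& f)
{
    unroll_impl(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

// e.g. with KPack = 8 and k_per_blk = 4 the loop body above runs for k = 0 and k = 1:
// unroll_for<8 / 4>([&](auto k) { /* mfma_instr.template run<MPerXdlops, NPerXdlops>(...) */ });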
composable_kernel/include/utility/amd_xdlops.hpp
...
...
@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
    ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
-       reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-               reg_a, reg_b, reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}], 1, 1, 0);
+       reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+           reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+       reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+           reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x1f32<32, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+       reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+           reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
+template <>
+struct intrin_mfma_f32_32x32x2f32<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+       reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
+           reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
+template <>
+struct intrin_mfma_f32_16x16x4f32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+       reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
+template <>
+struct intrin_mfma_f32_16x16x1f32<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
+       reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
+           reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
+       reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x1f32<8, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
-       reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
-               reg_a, reg_b, reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}], 4, 1, 0);
+       reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
+       reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x4f16<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
-       reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
-               reg_a, reg_b, reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}], 1, 1, 0);
+       reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
+           reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+       reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
+           reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
+template <>
+struct intrin_mfma_f32_32x32x4f16<32, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+       reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
+           reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;

-template <index_t COffset>
-struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
+template <>
+struct intrin_mfma_f32_32x32x8f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+       reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
+           reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
+template <>
+struct intrin_mfma_f32_16x16x16f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+       reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16;

-template <index_t COffset>
-struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
+template <>
+struct intrin_mfma_f32_16x16x4f16<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
+       reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
+           reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};
-template <index_t MPerWave, index_t NPerWave, index_t COffset>
+template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
+       reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

-template <index_t COffset>
-struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
+template <>
+struct intrin_mfma_f32_4x4x4f16<8, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
-       reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
-               reg_a, reg_b, reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
-       reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
-           llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
-               reg_a, reg_b, reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}], 4, 1, 0);
+       reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
+       reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
+           reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};
...
...
@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
...
...
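Throughout this file the last three integer arguments of the llvm.amdgcn.mfma intrinsics are the instruction's broadcast/blocking modifiers (commonly referred to as cbsz, abid and blgp in AMD's MFMA documentation; treat that naming as my reading, not something stated in this diff). A hedged sketch of a bare call outside the wrapper structs, reusing the declarations above:

// Sketch only: mfma_example is a hypothetical helper. float32_t is CK's 32-wide float vector;
// the trailing (1, 0, 0) matches the modifier values the 32x32x1 wrappers above pass.
__device__ void mfma_example(float a, float b, float32_t& acc)
{
    acc = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(a, b, acc, 1, 0, 0);
}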
composable_kernel/include/utility/config.hpp
...
...
@@ -90,8 +90,8 @@
#endif

// pass tensor descriptor by value or void*
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0

// merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
...
...
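This config.hpp hunk flips the experimental default from passing tensor descriptors through a void* to passing them by kernel-argument value. As a rough, hypothetical illustration of how a pair of mutually exclusive flags like these is typically consumed (not CK's actual dispatch code):

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
// descriptor is copied into the kernel argument block
template <typename Desc>
__global__ void run_kernel(Desc desc);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// descriptor is staged in device memory and read back through a void*
__global__ void run_kernel(const void* p_desc);
#endif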