gaoqiong / composable_kernel

Commit 5ce317cb, authored Oct 07, 2021 by Jing Zhang

    add fwd_driver_offline_nchwc

Parents: 71bc108d, b2dc55f8
Changes: 67
Showing 20 changed files with 3391 additions and 590 deletions (+3391 / -590)
Changed files shown on this page (additions / deletions):

  composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp  +90 -13
  composable_kernel/include/tensor_description/multi_index_transform_helper.hpp  +1 -1
  composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  +41 -41
  composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp  +2 -0
  composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp  +183 -261
  composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp  +625 -0
  composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp  +503 -0
  composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp  +544 -0
  composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp  +376 -0
  composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp  +271 -0
  composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp  +141 -0
  composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp  +371 -0
  composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp  +2 -4
  composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +25 -75
  composable_kernel/include/utility/amd_xdlops.hpp  +74 -183
  composable_kernel/include/utility/config.hpp  +2 -2
  composable_kernel/include/utility/data_type.hpp  +17 -10
  composable_kernel/include/utility/dynamic_buffer.hpp  +4 -0
  composable_kernel/include/utility/reduction_common.hpp  +53 -0
  composable_kernel/include/utility/reduction_enums.hpp  +66 -0
composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp

@@ -21,8 +21,8 @@ template <typename... Wei,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
-          index_t IYTildaValue,
+          typename IYTilda,
-          index_t IXTildaValue,
+          typename IXTilda,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
-    Number<IYTildaValue>,
+    IYTilda i_ytilda,
-    Number<IXTildaValue>,
+    IXTilda i_xtilda,
    Number<GemmK1Value>)
{
    constexpr auto I0 = Number<0>{};
@@ -43,8 +43,6 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};
-    constexpr auto IYTilda = Number<IYTildaValue>{};
-    constexpr auto IXTilda = Number<IXTildaValue>{};

    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
@@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;

    // GemmK is different for each GEMM
-    const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
+    const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
-    const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
+    const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);

    const auto K1 = GemmK1;
    const auto K0 = K / K1;
@@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                   make_slice_transform(YDot, I0, YDotSlice),
                   make_slice_transform(XDot, I0, XDotSlice),
-                   make_freeze_transform(IYTilda),
+                   make_freeze_transform(i_ytilda),
-                   make_freeze_transform(IXTilda),
+                   make_freeze_transform(i_xtilda),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
@@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
        in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
-                   make_freeze_transform(IYTilda),
+                   make_freeze_transform(i_ytilda),
                   make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
-                   make_freeze_transform(IXTilda),
+                   make_freeze_transform(i_xtilda),
                   make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{},
@@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        in_gemmm_gemmn_grid_desc);
}

+// A: out
+// B: wei
+// C: in
+// Number of GEMMs = 1
+// GemmM = N * Ho * Wo
+// GemmN = C
+// GemmK = K
+template <typename... Wei,
+          typename... In,
+          typename... Out,
+          typename ConvStrides,
+          index_t GemmK1Value>
+__host__ __device__ constexpr auto
+transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
+    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
+    const TensorDescriptor<Wei...>& /* wei_k_y_x_c_grid_desc */,
+    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
+    const ConvStrides& conv_strides,
+    Number<GemmK1Value>)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto GemmK1 = Number<GemmK1Value>{};
+
+    const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
+    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
+    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
+
+    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
+    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
+
+    const auto ConvStrideH = conv_strides[I0];
+    const auto ConvStrideW = conv_strides[I1];
+
+    const auto K1 = GemmK1;
+    const auto K0 = K / K1;
+
+    // A: output tensor
+    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
+        make_tuple(make_pass_through_transform(N * Ho * Wo),
+                   make_unmerge_transform(make_tuple(K0, K1))),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
+
+    // B: weight tensor
+    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
+        make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
+                   make_pass_through_transform(C)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+    // C: input tensor
+    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
+        in_n_hi_wi_c_grid_desc,
+        make_tuple(make_pass_through_transform(N),
+                   make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
+                   make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
+                   make_pass_through_transform(C)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
+
+    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
+        in_n_y_ho_x_wo_c_grid_desc,
+        make_tuple(make_freeze_transform(I0),
+                   make_freeze_transform(I0),
+                   make_merge_transform(make_tuple(N, Ho, Wo)),
+                   make_pass_through_transform(C)),
+        make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
+        make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
+
+    return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
+                      wei_gemmk0_gemmn_gemmk1_grid_desc,
+                      in_gemmm_gemmn_grid_desc);
+}

} // namespace ck
#endif
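Side note on the hunk above: the ytilda/xtilda indices stop being compile-time Number<> template values and become ordinary function arguments, so a single instantiation of the transform can serve every (i_ytilda, i_xtilda) pair chosen at run time. The snippet below is a minimal standalone sketch of that pattern, not code from this commit; integral_const and make_grid_desc_* are hypothetical stand-ins for the CK types.

#include <cstdio>

// Stand-in for ck::Number<>: a value encoded purely in the type.
template <int N> struct integral_const { static constexpr int value = N; };

// Before: the index is baked into the template, so each iy needs its own instantiation.
template <int IY>
void make_grid_desc_static(integral_const<IY>) { std::printf("static iy=%d\n", IY); }

// After: the index is an ordinary runtime argument; one instantiation covers all values.
void make_grid_desc_dynamic(int i_ytilda) { std::printf("dynamic iy=%d\n", i_ytilda); }

int main()
{
    make_grid_desc_static(integral_const<1>{}); // compile-time choice
    for(int iy = 0; iy < 2; ++iy)               // runtime choice, single instantiation
        make_grid_desc_dynamic(iy);
}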
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp

@@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform(
    return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}

-template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
+template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_right_pad_transform(
    const LowLength& low_length,
    const RightPadLength& right_pad,
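The one-line change above only adds a default value for SkipIsValidCheck, so callers may now omit the flag. The snippet below is a simplified illustration of that defaulted non-type template parameter; RightPad and make_right_pad_transform here are stand-ins, not the CK definitions.

#include <iostream>

template <typename Len, typename Pad, bool SkipIsValidCheck = false>
struct RightPad { Len low_length; Pad right_pad; };

template <typename Len, typename Pad, bool SkipIsValidCheck = false>
constexpr auto make_right_pad_transform(const Len& low_length, const Pad& right_pad)
{
    return RightPad<Len, Pad, SkipIsValidCheck>{low_length, right_pad};
}

int main()
{
    auto t1 = make_right_pad_transform(8, 2);                  // SkipIsValidCheck defaults to false
    auto t2 = make_right_pad_transform<int, int, true>(8, 2);  // explicit opt-out still possible
    std::cout << (t1.low_length + t1.right_pad) << ' ' << t2.low_length << '\n';
}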
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp

@@ -10,6 +10,7 @@ namespace ck {
template <index_t BlockSize,
          typename FloatAB,
+          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerXDL,
@@ -30,13 +31,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t K0        = BK0NK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t KPerBlock = K0;

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

+    StaticBufferV2<AddressSpaceEnum_t::Vgpr, vector_type<FloatAcc, 16>, MRepeat * NRepeat, true>
+        c_thread_buf_;
+
+    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
+
    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();
@@ -162,7 +167,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+            make_tuple(make_pass_through_transform(Number<K0>{}),
                       make_unmerge_transform(
                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
                       make_pass_through_transform(Number<K1>{})),
@@ -174,7 +179,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+            make_tuple(make_pass_through_transform(Number<K0>{}),
                       make_unmerge_transform(
                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
                       make_pass_through_transform(Number<K1>{})),
@@ -195,48 +200,43 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

-        vector_type<FloatAB, K1> a_thread_vec;
-        vector_type<FloatAB, K1> b_thread_vec;
-
-        static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
+        static_for<0, MRepeat, 1>{}([&](auto m0) {
            // read A
            a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
-                               make_tuple(k0, I0, I0, I0, I0),
+                               make_tuple(I0, m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0, I0),
                               a_thread_buf);

+            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B
                b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
-                                   make_tuple(k0, I0, I0, I0, I0),
+                                   make_tuple(I0, n0, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0, I0),
                                   b_thread_buf);

-            using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
-
-            static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) {
+                    vector_type<FloatAB, K1> a_thread_vec;
+                    vector_type<FloatAB, K1> b_thread_vec;
+
                    static_for<0, K1, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
-                    });
-
-                    static_for<0, K1, 1>{}([&](auto i) {
+                            [Number<a_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
-                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
+                            [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, 0, 0, 0, i))>{}];
                    });

-                    constexpr index_t c_offset =
-                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                    using mfma_input_type =
+                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
+
+                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0));

-                    xdlops_gemm.template Run<c_offset>(
+                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf);
+                        c_thread_buf.GetVector(Number<c_offset>{}));
                });
            });
        });
@@ -244,35 +244,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    private:
    // A[K, M]
-    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
+    static constexpr auto a_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));

    // B[K, N]
-    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
+    static constexpr auto b_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));

-    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
+    static constexpr auto c_thread_desc_ =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_k0_m0_m1_m2_k1_block_desc),
                                                         decltype(a_thread_desc_),
-                                                         Sequence<1, MRepeat, 1, 1, K1>,
+                                                         Sequence<K0, 1, 1, 1, K1>,
                                                         Sequence<0, 1, 2, 3, 4>,
                                                         4,
                                                         K1,
-                                                         1>;
+                                                         K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_k0_n0_n1_n2_k1_block_desc),
                                                         decltype(b_thread_desc_),
-                                                         Sequence<1, NRepeat, 1, 1, K1>,
+                                                         Sequence<K0, 1, 1, 1, K1>,
                                                         Sequence<0, 1, 2, 3, 4>,
                                                         4,
                                                         K1,
-                                                         1>;
+                                                         K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
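In the hunk above the blockwise GEMM gains its own per-thread accumulator (c_thread_buf_) and an accessor, instead of having every result pushed through a caller-supplied buffer index. The sketch below shows that "owning member buffer plus accessor" pattern in isolation; it is an illustration with made-up sizes, not the CK class.

#include <array>
#include <cstdio>

template <int MRepeat, int NRepeat>
struct BlockwiseGemmSketch
{
    std::array<float, MRepeat * NRepeat> c_thread_buf_{}; // per-thread accumulator, owned here

    constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    void Run()
    {
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int n0 = 0; n0 < NRepeat; ++n0)
                c_thread_buf_[m0 * NRepeat + n0] += 1.0f; // accumulate one (m0, n0) tile
    }
};

int main()
{
    BlockwiseGemmSketch<2, 2> gemm;
    gemm.Run();
    std::printf("%f\n", gemm.GetCThreadBuffer()[3]); // caller reads results via the accessor
}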
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp

@@ -485,6 +485,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                   true>
            c_thread_buf;

+#if 0
        // initialize output thread tensor
        ThreadwiseTensorSliceSet_v1<FloatAcc,
                                    decltype(c_k1_n_h2_w2_thread_gemm_desc),
@@ -493,6 +494,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 FloatAcc{0});
+#endif

        constexpr auto b_thread_slice_copy_step =
            make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0);
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp (diff collapsed, not shown on this page)
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp (new file, mode 100644; diff collapsed, not shown on this page)
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_threadwise.hpp (new file, mode 100644; diff collapsed, not shown on this page)
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_direct_warpwise.hpp (new file, mode 100644; diff collapsed, not shown on this page)
composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_multiblock.hpp (new file, mode 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_GENERIC_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_blockwise.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
namespace ck {

template <index_t BlockSize,
          typename srcDataType,
          typename dstDataType, // not used together with the beta input
          typename compType,
          typename src2dDescType,
          typename dst1dDescType,
          ReduceTensorOp_t op,
          NanPropagation_t nanPropaOpt,
          ReduceTensorIndices_t reduceIndicesOpt,
          index_t GredAccessesPerThreadInBlock>
struct GridwiseReduction_xy_to_x_multiblock
{
    using opReduce       = typename reduce_binary_operator<compType, op>::opType;
    using preUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::preUnaryOp;
    using posUnaryOpType = typename reduce_unary_operator<compType, op, true, false>::posUnaryOp;

    static constexpr auto buffer2dDesc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<GredAccessesPerThreadInBlock>{}, Number<BlockSize>{}));
    using blockwise_reduce =
        BlockwiseReduction_2d_block_buffer<decltype(buffer2dDesc), true, opReduce, nanPropaOpt>;

    static constexpr index_t BlockBufferSize = buffer2dDesc.GetElementSize();

    static constexpr auto I0 = Number<0>{};

    template <int RunId>
    __device__ static void Run(const src2dDescType& src2dDesc,
                               const dst1dDescType& dst1dDesc,
                               int origReduceLen,
                               int BlkGroupSize,
                               srcDataType alpha,
                               const srcDataType* const __restrict__ p_src_global,
                               dstDataType beta,
                               srcDataType* const __restrict__ ws_values_global,
                               int* const __restrict__ ws_indices_global);

    template <>
    __device__ static void Run<1>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  int BlkGroupSize,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  srcDataType* const __restrict__ ws_values_global,
                                  int* const __restrict__ ws_indices_global)
    {
        (void)ws_indices_global;
        (void)alpha; // unused
        (void)beta;  // unused

        const auto zeroVal = opReduce::GetReductionZeroVal();

        // LDS
        __shared__ compType p_in_block_buffer[BlockBufferSize];

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);

        auto in_block_buf =
            make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_in_block_buffer, BlockBufferSize);
        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;

        accuValue_buf(I0) = zeroVal;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / BlkGroupSize;
        const index_t block_local_id  = block_global_id % BlkGroupSize;

        const index_t reduceSizePerBlock =
            (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
             BlockBufferSize) *
            BlockBufferSize;

        constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_load =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
                                            Sequence<1, BlockBufferSize>,
                                            ThreadSliceLengths,
                                            ThreadClusterLengths,
                                            Sequence<0, 1>,
                                            srcDataType,
                                            compType,
                                            src2dDescType,
                                            decltype(in_block_desc),
                                            Sequence<0, 1>,
                                            Sequence<0, 1>,
                                            1, 1, 1, 1, 1, 1,
                                            false,
                                            true>(
                src2dDesc,
                make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
                in_block_desc,
                make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            blockwise_src_load.RunRead(src2dDesc, src_global_buf);
            blockwise_src_load.RunWrite(in_block_desc, in_block_buf);

            __syncthreads();

            // do element-wise pre-reduction operation
            blockwise_reduce::operate_on_elements(preUnaryOp, in_block_buf);

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;
            blockwise_reduce::Reduce(in_block_buf, BlocksInOneOp, accuValue_buf(I0));

            blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        const auto workspace_desc =
            make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            auto threadwise_workspace_store =
                ThreadwiseTensorSliceTransfer_v1r3<compType,
                                                   srcDataType,
                                                   decltype(ReducedDataDesc),
                                                   decltype(workspace_desc),
                                                   Sequence<1>,
                                                   Sequence<0>,
                                                   0,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(workspace_desc,
                                                         make_multi_index(block_global_id));

            threadwise_workspace_store.Run(ReducedDataDesc,
                                           make_tuple(I0),
                                           accuValue_buf,
                                           workspace_desc,
                                           workspace_global_buf);
        }
    };

    template <>
    __device__ static void Run<2>(const src2dDescType& src2dDesc,
                                  const dst1dDescType& dst1dDesc,
                                  int origReduceLen,
                                  int BlkGroupSize,
                                  srcDataType alpha,
                                  const srcDataType* const __restrict__ p_src_global,
                                  dstDataType beta,
                                  srcDataType* const __restrict__ ws_values_global,
                                  int* const __restrict__ ws_indices_global)
    {
        (void)alpha; // unused
        (void)beta;  // unused

        const auto zeroVal = opReduce::GetReductionZeroVal();

        // LDS
        __shared__ compType p_in_block_values_buffer[BlockBufferSize];
        __shared__ int p_in_block_indices_buffer[BlockBufferSize];

        const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
        auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_values_global, dst1dDesc.GetLength(I0) * BlkGroupSize);
        auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            ws_indices_global, dst1dDesc.GetLength(I0) * BlkGroupSize);

        auto in_block_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_in_block_values_buffer, BlockBufferSize);
        auto in_block_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_in_block_indices_buffer, BlockBufferSize);

        StaticBuffer<AddressSpaceEnum_t::Vgpr, compType, 1, true> accuValue_buf;
        StaticBuffer<AddressSpaceEnum_t::Vgpr, int, 1, true> accuIndex_buf;

        accuValue_buf(I0) = zeroVal;
        accuIndex_buf(I0) = 0;

        const auto toReduceLength = src2dDesc.GetLength(Number<1>{});
        const int divider         = origReduceLen;

        const preUnaryOpType preUnaryOp(divider);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / BlkGroupSize;
        const index_t block_local_id  = block_global_id % BlkGroupSize;

        const index_t reduceSizePerBlock =
            (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
             BlockBufferSize) *
            BlockBufferSize;

        constexpr auto in_block_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<1>{}, Number<BlockSize * GredAccessesPerThreadInBlock>{}));

        using ThreadSliceLengths   = Sequence<1, GredAccessesPerThreadInBlock>;
        using ThreadClusterLengths = Sequence<1, BlockSize>;

        auto blockwise_src_load =
            BlockwiseTensorSliceTransfer_v4<BlockSize,
                                            InMemoryDataOperationEnum_t::Set,
                                            Sequence<1, BlockBufferSize>,
                                            ThreadSliceLengths,
                                            ThreadClusterLengths,
                                            Sequence<0, 1>,
                                            srcDataType,
                                            compType,
                                            src2dDescType,
                                            decltype(in_block_desc),
                                            Sequence<0, 1>,
                                            Sequence<0, 1>,
                                            1, 1, 1, 1, 1, 1,
                                            false,
                                            true>(
                src2dDesc,
                make_multi_index(blkgroup_id, block_local_id * reduceSizePerBlock),
                in_block_desc,
                make_multi_index(0, 0));

        constexpr auto in_block_copy_step = make_multi_index(0, BlockBufferSize);

        const index_t toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;

        int indexOffset = block_local_id * reduceSizePerBlock;

        for(index_t reducedBlocks = 0; reducedBlocks < toReduceBlocks;
            reducedBlocks += GredAccessesPerThreadInBlock)
        {
            blockwise_reduce::init_buffer_indices(in_block_idx_buf, indexOffset);

            blockwise_src_load.RunRead(src2dDesc, src_global_buf);
            blockwise_src_load.RunWrite(in_block_desc, in_block_val_buf);

            __syncthreads();

            // unary operation before reducing, needed by AMAX; For MIN/MAX, nothing is actually
            // done here
            blockwise_reduce::operate_on_elements(preUnaryOp, in_block_val_buf);

            index_t BlocksInOneOp = (reducedBlocks < toReduceBlocks - GredAccessesPerThreadInBlock)
                                        ? GredAccessesPerThreadInBlock
                                        : toReduceBlocks - reducedBlocks;
            blockwise_reduce::Reduce2(in_block_val_buf,
                                      in_block_idx_buf,
                                      BlocksInOneOp,
                                      accuValue_buf(I0),
                                      accuIndex_buf(I0));

            indexOffset += BlockBufferSize;

            blockwise_src_load.MoveSrcSliceWindow(src2dDesc, in_block_copy_step);
        }

        constexpr auto ReducedDataDesc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}));

        const auto workspace_desc =
            make_naive_tensor_descriptor_packed(make_tuple(dst1dDesc.GetLength(I0) * BlkGroupSize));

        // The first thread in the block stores the reduced result to the global location
        // representing the block
        if(thread_local_id == 0)
        {
            auto threadwise_workspace_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<compType,
                                                   srcDataType,
                                                   decltype(ReducedDataDesc),
                                                   decltype(workspace_desc),
                                                   Sequence<1>,
                                                   Sequence<0>,
                                                   0,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(workspace_desc,
                                                         make_multi_index(block_global_id));

            auto threadwise_workspace_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<int,
                                                   int,
                                                   decltype(ReducedDataDesc),
                                                   decltype(workspace_desc),
                                                   Sequence<1>,
                                                   Sequence<0>,
                                                   0,
                                                   1,
                                                   InMemoryDataOperationEnum_t::Set,
                                                   1,
                                                   true>(workspace_desc,
                                                         make_multi_index(block_global_id));

            threadwise_workspace_val_store.Run(ReducedDataDesc,
                                               make_tuple(I0),
                                               accuValue_buf,
                                               workspace_desc,
                                               workspace_global_val_buf);
            threadwise_workspace_idx_store.Run(ReducedDataDesc,
                                               make_tuple(I0),
                                               accuIndex_buf,
                                               workspace_desc,
                                               workspace_global_idx_buf);
        }
    };
};

} // namespace ck
#endif
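The multiblock kernel above splits the reduced dimension across BlkGroupSize blocks, each handling a per-block share rounded up to whole BlockBufferSize tiles. The snippet below is a worked host-side example of that arithmetic only; the concrete numbers are made up for illustration and are not taken from the commit.

#include <cstdio>

int main()
{
    const int toReduceLength  = 10000; // length of the reduced dimension (hypothetical)
    const int BlkGroupSize    = 4;     // blocks cooperating on one output row
    const int BlockBufferSize = 2048;  // GredAccessesPerThreadInBlock * BlockSize
    const int BlockSize       = 256;

    // Per-block share, rounded up to a whole number of BlockBufferSize tiles.
    const int reduceSizePerBlock =
        (((toReduceLength + BlkGroupSize - 1) / BlkGroupSize + BlockBufferSize - 1) /
         BlockBufferSize) *
        BlockBufferSize;

    const int toReduceBlocks = (reduceSizePerBlock + BlockSize - 1) / BlockSize;

    // 10000 / 4 -> 2500 elements per block, rounded up to 2 * 2048 = 4096,
    // i.e. 16 BlockSize-wide chunks consumed GredAccessesPerThreadInBlock at a time.
    std::printf("reduceSizePerBlock=%d toReduceBlocks=%d\n", reduceSizePerBlock, toReduceBlocks);
}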
composable_kernel/include/tensor_operation/reduction_functions_blockwise.hpp (new file, mode 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {

template <typename buffer2dDescType,
          bool blockIsOneRow,
          typename opReduce,
          NanPropagation_t nanPropaOpt>
struct BlockwiseReduction_2d_block_buffer
{
    using compType = typename opReduce::dataType;

    static constexpr auto buffer2dDesc = buffer2dDescType{};

    static constexpr index_t BlockSize =
        blockIsOneRow ? buffer2dDesc.GetLength(Number<1>{}) : buffer2dDesc.GetLength(Number<0>{});
    static constexpr index_t NumBlocks =
        blockIsOneRow ? buffer2dDesc.GetLength(Number<0>{}) : buffer2dDesc.GetLength(Number<1>{});

    using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    // This interface does not accumulate on indices
    template <typename BufferType>
    __device__ static void
    Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
    {
        const index_t thread_local_id = get_thread_local_1d_id();
        compType lAccuData            = opReduce::GetReductionZeroVal();

        index_t offset;

        for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
        {
            offset = blockIsOneRow
                         ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id))
                         : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));
            compType opData = type_convert<compType>{}(block_buffer[offset]);

            binop::calculate(lAccuData, opData);
        }

        offset = blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
                               : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));

        block_buffer(offset) = lAccuData;

        __syncthreads();

        for(index_t indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2)
        {
            if(thread_local_id < indOffset)
            {
                index_t offset1 =
                    blockIsOneRow ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id))
                                  : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
                index_t offset2 =
                    blockIsOneRow
                        ? buffer2dDesc.CalculateOffset(make_tuple(0, thread_local_id + indOffset))
                        : buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));

                compType opData1 = type_convert<compType>{}(block_buffer[offset1]);
                compType opData2 = type_convert<compType>{}(block_buffer[offset2]);
                binop::calculate(opData1, opData2);
                block_buffer(offset1) = type_convert<compType>{}(opData1);
            }

            __syncthreads();
        }

        if(thread_local_id == 0)
        {
            compType tmpVal = type_convert<compType>{}(block_buffer[0]);

            binop::calculate(accuData, tmpVal);
        }
    };

    // This interface accumulates on both data values and indices
    template <typename BufferType, typename IdxBufferType>
    __device__ static void Reduce2(BufferType& block_buffer,
                                   IdxBufferType& block_indices_buffer,
                                   index_t toReduceBlocks,
                                   compType& accuData,
                                   int& accuIndex)
    {
        const index_t thread_local_id = get_thread_local_1d_id();
        compType lAccuData            = opReduce::GetReductionZeroVal();
        int lAccuIndex                = 0;

        if constexpr(blockIsOneRow)
        {
            for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
            {
                for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
                {
                    if(thread_local_id % (indOffset * 2) == 0)
                    {
                        index_t offset1 =
                            buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_local_id));
                        index_t offset2 = buffer2dDesc.CalculateOffset(
                            make_tuple(otherDimInd, thread_local_id + indOffset));

                        compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
                        compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
                        int currIndex1    = block_indices_buffer[offset1];
                        int currIndex2    = block_indices_buffer[offset2];

                        binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
                        block_buffer(offset1)         = type_convert<compType>{}(currVal1);
                        block_indices_buffer(offset1) = currIndex1;
                    }
                    __syncthreads();
                }
            }

            if(thread_local_id == 0)
            {
                for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
                {
                    index_t offset = buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, 0));

                    compType tmpVal = type_convert<compType>{}(block_buffer[offset]);
                    int tmpIndex    = block_indices_buffer[offset];

                    binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);
                }

                binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
            }
        }
        else
        {
            index_t offset;

            for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
            {
                offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, otherDimInd));

                compType currVal = type_convert<compType>{}(block_buffer[offset]);
                int currIndex    = block_indices_buffer[offset];

                binop::calculate(lAccuData, currVal, lAccuIndex, currIndex);
            }

            offset = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));

            block_buffer(offset)         = lAccuData;
            block_indices_buffer(offset) = lAccuIndex;

            __syncthreads();

            for(index_t indOffset = 1; indOffset < BlockSize; indOffset *= 2)
            {
                if(thread_local_id % (indOffset * 2) == 0)
                {
                    index_t offset1 = buffer2dDesc.CalculateOffset(make_tuple(thread_local_id, 0));
                    index_t offset2 =
                        buffer2dDesc.CalculateOffset(make_tuple(thread_local_id + indOffset, 0));

                    compType currVal1 = type_convert<compType>{}(block_buffer[offset1]);
                    compType currVal2 = type_convert<compType>{}(block_buffer[offset2]);
                    int currIndex1    = block_indices_buffer[offset1];
                    int currIndex2    = block_indices_buffer[offset2];

                    binop::calculate(currVal1, currVal2, currIndex1, currIndex2);
                    block_buffer(offset1)         = type_convert<compType>{}(currVal1);
                    block_indices_buffer(offset1) = currIndex1;
                }
                __syncthreads();
            }

            if(thread_local_id == 0)
            {
                compType tmpVal = type_convert<compType>{}(block_buffer[0]);
                int tmpIndex    = block_indices_buffer[0];

                binop::calculate(accuData, tmpVal, accuIndex, tmpIndex);
            }
        }
    };

    template <typename BufferType>
    __device__ static void set_buffer_value(BufferType& block_buffer, compType value)
    {
        index_t thread_id = get_thread_local_1d_id();

        for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
        {
            index_t offset = blockIsOneRow
                                 ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
                                 : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));

            block_buffer(offset) = value;

            __syncthreads();
        }
    };

    // Initialize the block-wise indices buffer, the index for each element in the block-wise data
    // buffer is calculated according to its position in the buffer and the global starting index
    template <typename IdxBufferType>
    __device__ static void init_buffer_indices(IdxBufferType& block_indices_buffer, int indexStart)
    {
        index_t thread_id = get_thread_local_1d_id();

        for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
        {
            index_t offset = blockIsOneRow
                                 ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
                                 : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));

            block_indices_buffer(offset) = offset + indexStart;

            __syncthreads();
        }
    };

    // Execute unary operation on the block buffer elements
    template <typename unary_op_type, typename BufferType>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& block_buffer)
    {
        index_t thread_id = get_thread_local_1d_id();

        for(index_t otherDimInd = 0; otherDimInd < NumBlocks; otherDimInd++)
        {
            index_t offset = blockIsOneRow
                                 ? buffer2dDesc.CalculateOffset(make_tuple(otherDimInd, thread_id))
                                 : buffer2dDesc.CalculateOffset(make_tuple(thread_id, otherDimInd));

            block_buffer(offset) = unary_op(block_buffer[offset]);

            __syncthreads();
        }
    };
};

}; // end of namespace ck
#endif
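BlockwiseReduction_2d_block_buffer::Reduce above first has every thread fold its own column into slot [tid], then repeatedly folds the upper half of the active slots into the lower half until slot [0] holds the block-wide result. The snippet below is a serial host-side sketch of that halving pattern only; plain arrays and a loop stand in for LDS and __syncthreads().

#include <cstdio>

int main()
{
    const int BlockSize = 8;
    float block_buffer[BlockSize] = {1, 2, 3, 4, 5, 6, 7, 8};

    for(int indOffset = BlockSize / 2; indOffset > 0; indOffset /= 2)
        for(int tid = 0; tid < indOffset; ++tid) // "threads" below indOffset stay active
            block_buffer[tid] += block_buffer[tid + indOffset];

    std::printf("sum=%f\n", block_buffer[0]); // 36
}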
composable_kernel/include/tensor_operation/reduction_functions_threadwise.hpp (new file, mode 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace ck {

template <typename BufferType, typename opReduce, NanPropagation_t nanPropaOpt>
struct ThreadReduce
{
    using compType = typename opReduce::dataType;

    static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!");
    static_assert(std::is_same<typename BufferType::type, compType>::value,
                  "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!");

    static constexpr index_t ThreadBufferLen = BufferType::Size();

    using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    // This interface does not accumulate on indices
    __device__ static void Reduce(const BufferType& thread_buffer, compType& accuData)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { binop::calculate(accuData, thread_buffer[I]); });
    };

    // This interface accumulates on both data values and indices and
    // is called by Direct_ThreadWise reduction method at first-time reduction
    __device__ static void
    Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            int currIndex = I + indexStart;
            binop::calculate(accuData, thread_buffer[I], accuIndex, currIndex);
        });
    };

    // Set the elements in the per-thread buffer to a specific value
    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
    };

    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
    };
};

template <typename BufferType,
          typename IdxBufferType,
          typename opReduce,
          NanPropagation_t nanPropaOpt>
struct ThreadReduceWithIndicesInput
{
    using compType = typename opReduce::dataType;

    static_assert(BufferType::IsStaticBuffer(), "Thread-wise reduction needs use StaticBuffer!");
    static_assert(IdxBufferType::IsStaticBuffer(),
                  "Thread-wise reduction needs use StaticBuffer for indices!");
    static_assert(std::is_same<typename BufferType::type, compType>::value,
                  "Data type of StaticBuffer for Thread-wise reduction should be same as the compType!");
    static_assert(std::is_same<typename IdxBufferType::type, index_t>::value,
                  "Indices type of StaticBuffer for Thread-wise reduction should be index_t!");
    static_assert(BufferType::Size() == IdxBufferType::Size(),
                  "StaticBuffers for data and indices should have the same sizes!");

    static constexpr index_t ThreadBufferLen = BufferType::Size();

    using binop = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    // This interface accumulates on both data values and indices and
    // is called by Direct_ThreadWise reduction method at second-time reduction
    __device__ static void Reduce(const BufferType& thread_buffer,
                                  const IdxBufferType& thread_indices_buffer,
                                  compType& accuData,
                                  int& accuIndex)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            binop::calculate(accuData, thread_buffer[I], accuIndex, thread_indices_buffer[I]);
        });
    };

    // Set the elements in the per-thread buffer to a specific value
    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });
    };

    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });
    };
};

}; // end of namespace ck
#endif
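ThreadReduce::Reduce2 above walks a small per-thread buffer while tracking both the running value and the global index at which it was found (indexStart is the buffer's offset into the reduced dimension). The snippet below is a host-side sketch of that idea with a max-reduction standing in for the opReduce binary operator; the numbers are invented for illustration.

#include <cstdio>

int main()
{
    const float thread_buffer[4] = {0.5f, 2.5f, 1.0f, 2.0f};
    const int indexStart         = 16; // hypothetical global offset of this thread's slice

    float accuData = -1e30f; // "zero value" of a max-reduction
    int accuIndex  = 0;

    for(int I = 0; I < 4; ++I)
    {
        int currIndex = I + indexStart;
        if(thread_buffer[I] > accuData)
        {
            accuData  = thread_buffer[I];
            accuIndex = currIndex;
        }
    }

    std::printf("max=%f at %d\n", accuData, accuIndex); // max=2.5 at 17
}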
composable_kernel/include/tensor_operation/reduction_functions_warpwise.hpp (new file, mode 100644)
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
#define CK_REDUCTION_FUNCTIONS_WARPWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_binop.hpp"
namespace
ck
{
template
<
typename
BufferType
,
index_t
BlockSize
,
typename
opReduce
,
NanPropagation_t
nanPropaOpt
>
struct
WarpReduce
{
using
compType
=
typename
opReduce
::
dataType
;
using
binop
=
detail
::
binop_with_nan_check
<
nanPropaOpt
,
opReduce
,
compType
>
;
static_assert
(
BufferType
::
IsStaticBuffer
(),
"Per-thread buffer for WarpWise reduction should be StaticBuffer!"
);
static_assert
(
std
::
is_same
<
typename
BufferType
::
type
,
compType
>::
value
,
"Data type of per-thread StaticBuffer for WarpWise reduction should be same as "
"the compType!"
);
static
constexpr
index_t
ThreadBufferLen
=
BufferType
::
Size
();
static
constexpr
bool
have_builtin_shuffle
=
std
::
is_same
<
compType
,
float
>::
value
||
std
::
is_same
<
compType
,
double
>::
value
;
// This interface does not accumulate on indices
__device__
static
void
Reduce
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
)
{
if
constexpr
(
have_builtin_shuffle
)
ReduceImpl1
(
thread_buffer
,
accuData
);
else
ReduceImpl2
(
thread_buffer
,
accuData
);
};
// This interface implementation uses HIP built-in device shuffling functions
__device__
static
void
ReduceImpl1
(
const
BufferType
&
thread_buffer
,
compType
&
accuData
)
{
compType
lAccuData
=
opReduce
::
GetReductionZeroVal
();
static_for
<
0
,
ThreadBufferLen
,
1
>
{}(
[
&
](
auto
I
)
{
binop
::
calculate
(
lAccuData
,
thread_buffer
[
I
]);
});
// synchronize among all threads in this warp
__all
(
1
);
for
(
index_t
stride
=
warpSize
/
2
;
stride
>
0
;
stride
/=
2
)
{
compType
tmpVal
=
__shfl_down
(
                lAccuData, stride, warpSize);

            binop::calculate(lAccuData, tmpVal);

            __all(1);
        }

        binop::calculate(accuData, lAccuData);
    };

    // This interface implementation does not use HIP built-in device shuffling functions
    // since for fp16, built-in shuffling functions are not provided by HIP
    __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();

        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });

        __syncthreads();

        index_t thread_id        = get_thread_local_1d_id();
        index_t warpId           = thread_id / warpSize;
        index_t thread_inwarp_id = thread_id % warpSize;

        __shared__ compType shuffle_buffer[BlockSize];

        compType* myBuffer = &shuffle_buffer[warpId * warpSize];

        myBuffer[thread_inwarp_id] = lAccuData;

        __syncthreads();

        for(index_t stride = warpSize / 2; stride > 0; stride /= 2)
        {
            if(thread_inwarp_id < stride)
            {
                compType currVal1 = myBuffer[thread_inwarp_id];
                compType currVal2 = myBuffer[thread_inwarp_id + stride];

                binop::calculate(currVal1, currVal2);

                myBuffer[thread_inwarp_id] = currVal1;
            }

            __syncthreads();
        }

        if(thread_inwarp_id == 0)
            binop::calculate(accuData, myBuffer[0]);
    };

    // This interface accumulates on both data values and indices and is called by Direct_WarpWise
    // reduction method at first-time reduction
    __device__ static void
    Reduce2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
    {
        if constexpr(have_builtin_shuffle)
            Reduce2Impl1(thread_buffer, accuData, accuIndex, indexStart);
        else
            Reduce2Impl2(thread_buffer, accuData, accuIndex, indexStart);
    };

    // This interface implementation uses HIP built-in device shuffling functions
    __device__ static void
    Reduce2Impl1(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();
        int lAccuIndex     = 0;

        index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize;

        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
            binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
        });

        // synchronize among all threads in this warp
        __all(1);

        for(index_t stride = 1; stride < warpSize; stride *= 2)
        {
            compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
            int tmpIndex    = __shfl_down(lAccuIndex, stride, warpSize);

            binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);

            __all(1);
        }

        if(thread_inwarp_id == 0)
            binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
    };

    // This interface implementation does not use HIP built-in device shuffling functions since for
    // fp16, built-in shuffling functions are not provided by HIP
    __device__ static void
    Reduce2Impl2(const BufferType& thread_buffer, compType& accuData, int& accuIndex, int indexStart)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();
        int lAccuIndex     = 0;

        index_t thread_id        = get_thread_local_1d_id();
        index_t warpId           = thread_id / warpSize;
        index_t thread_inwarp_id = thread_id % warpSize;

        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            int currIndex = thread_inwarp_id * ThreadBufferLen + I + indexStart;
            binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, currIndex);
        });

        __shared__ compType shuffle_data_buffer[BlockSize];
        __shared__ int shuffle_indices_buffer[BlockSize];

        compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
        int* myIndicesBuffer   = &shuffle_indices_buffer[warpId * warpSize];

        myDataBuffer[thread_inwarp_id]    = lAccuData;
        myIndicesBuffer[thread_inwarp_id] = lAccuIndex;

        __syncthreads();

        for(index_t stride = 1; stride < warpSize; stride *= 2)
        {
            compType currVal1 = myDataBuffer[thread_inwarp_id];
            compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
            int currIndex1    = myIndicesBuffer[thread_inwarp_id];
            int currIndex2    = myIndicesBuffer[thread_inwarp_id + stride];

            binop::calculate(currVal1, currVal2, currIndex1, currIndex2);

            myDataBuffer[thread_inwarp_id]    = currVal1;
            myIndicesBuffer[thread_inwarp_id] = currIndex1;

            __syncthreads();
        }

        if(thread_inwarp_id == 0)
            binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
    };

    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });

        __all(1);
    };

    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });

        __all(1);
    };
};
template <typename BufferType,
          typename IdxBufferType,
          index_t BlockSize,
          typename opReduce,
          NanPropagation_t nanPropaOpt>
struct WarpReduceWithIndicesInput
{
    using compType = typename opReduce::dataType;
    using binop    = detail::binop_with_nan_check<nanPropaOpt, opReduce, compType>;

    static_assert(BufferType::IsStaticBuffer(),
                  "Per-thread buffer for WarpWise reduction should be StaticBuffer!");
    static_assert(IdxBufferType::IsStaticBuffer(),
                  "Per-thread buffer for WarpWise reduction should be StaticBuffer for indices!");

    static_assert(std::is_same<typename BufferType::type, compType>::value,
                  "Data type of per-thread StaticBuffer for WarpWise reduction should be same as "
                  "the compType!");
    static_assert(std::is_same<typename IdxBufferType::type, index_t>::value,
                  "Indices type of per-thread StaticBuffer for WarpWise reduction should be index_t!");

    static_assert(BufferType::Size() == IdxBufferType::Size(),
                  "StaticBuffers for data and indices should have the same sizes!");

    static constexpr index_t ThreadBufferLen = BufferType::Size();

    static constexpr bool have_builtin_shuffle =
        std::is_same<compType, float>::value || std::is_same<compType, double>::value;

    // This interface accumulates on both data values and indices and is called by Direct_WarpWise
    // reduction method at second-time reduction
    __device__ static void Reduce(const BufferType& thread_buffer,
                                  const IdxBufferType& thread_indices_buffer,
                                  compType& accuData,
                                  int& accuIndex)
    {
        if constexpr(have_builtin_shuffle)
            ReduceImpl1(thread_buffer, thread_indices_buffer, accuData, accuIndex);
        else
            ReduceImpl2(thread_buffer, thread_indices_buffer, accuData, accuIndex);
    };

    // This interface implementation uses HIP built-in device shuffling functions
    __device__ static void ReduceImpl1(const BufferType& thread_buffer,
                                       const IdxBufferType& thread_indices_buffer,
                                       compType& accuData,
                                       int& accuIndex)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();
        int lAccuIndex     = 0;

        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
        });

        // synchronize among all threads in this warp
        __all(1);

        for(index_t stride = 1; stride < warpSize; stride *= 2)
        {
            compType tmpVal = __shfl_down(lAccuData, stride, warpSize);
            int tmpIndex    = __shfl_down(lAccuIndex, stride, warpSize);

            binop::calculate(lAccuData, tmpVal, lAccuIndex, tmpIndex);

            __all(1);
        }

        binop::calculate(accuData, lAccuData, accuIndex, lAccuIndex);
    };

    // This interface implementation does not use HIP built-in device shuffling functions
    // since for fp16, built-in shuffling functions are not provided by HIP
    __device__ static void ReduceImpl2(const BufferType& thread_buffer,
                                       const IdxBufferType& thread_indices_buffer,
                                       compType& accuData,
                                       int& accuIndex)
    {
        compType lAccuData = opReduce::GetReductionZeroVal();
        int lAccuIndex     = 0;

        index_t thread_id        = get_thread_local_1d_id();
        index_t warpId           = thread_id / warpSize;
        index_t thread_inwarp_id = thread_id % warpSize;

        static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
            binop::calculate(lAccuData, thread_buffer[I], lAccuIndex, thread_indices_buffer[I]);
        });

        __shared__ compType shuffle_data_buffer[BlockSize];
        __shared__ int shuffle_indices_buffer[BlockSize];

        compType* myDataBuffer = &shuffle_data_buffer[warpId * warpSize];
        int* myIndicesBuffer   = &shuffle_indices_buffer[warpId * warpSize];

        myDataBuffer[thread_inwarp_id]    = lAccuData;
        myIndicesBuffer[thread_inwarp_id] = lAccuIndex;

        __syncthreads();

        for(index_t stride = 1; stride < warpSize; stride *= 2)
        {
            compType currVal1 = myDataBuffer[thread_inwarp_id];
            compType currVal2 = myDataBuffer[thread_inwarp_id + stride];
            int currIndex1    = myIndicesBuffer[thread_inwarp_id];
            int currIndex2    = myIndicesBuffer[thread_inwarp_id + stride];

            binop::calculate(currVal1, currVal2, currIndex1, currIndex2);

            myDataBuffer[thread_inwarp_id]    = currVal1;
            myIndicesBuffer[thread_inwarp_id] = currIndex1;

            __syncthreads();
        }

        if(thread_inwarp_id == 0)
            binop::calculate(accuData, myDataBuffer[0], accuIndex, myIndicesBuffer[0]);
    };

    // cppcheck-suppress constParameter
    __device__ static void set_buffer_value(BufferType& thread_buffer, compType value)
    {
        static_for<0, ThreadBufferLen, 1>{}([&](auto I) { thread_buffer(I) = value; });

        __all(1);
    };

    // Execute unary operation on the per-thread buffer elements
    template <typename unary_op_type>
    __device__ static void operate_on_elements(unary_op_type& unary_op, BufferType& thread_buffer)
    {
        static_for<0, ThreadBufferLen, 1>{}(
            [&](auto I) { thread_buffer(I) = unary_op(thread_buffer[I]); });

        __all(1);
    };
};

}; // end of namespace ck

#endif
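The two warp-level helpers above collapse a per-thread StaticBuffer into one value (and optionally one index) per warp, either with __shfl_down or, when the data type has no built-in shuffle, through an LDS staging buffer and a stride loop. The following host-side sketch is not part of the commit; it only models the stride-doubling pattern of the LDS fallback path for a single simulated warp, using a max-with-index operation as a stand-in for the opReduce binop. All names and sizes here are illustrative.

// Host-only illustration of the tree reduction used in Reduce2Impl2/ReduceImpl2 above.
#include <iostream>
#include <vector>

int main()
{
    const int warp_size = 64; // stands in for warpSize

    // one partial accumulation per lane, plus the flat index it came from
    std::vector<float> data(warp_size);
    std::vector<int> idx(warp_size);
    for(int lane = 0; lane < warp_size; ++lane)
    {
        data[lane] = static_cast<float>((lane * 37) % 50); // arbitrary values
        idx[lane]  = lane;
    }

    // stride-doubling tree, mirroring for(stride = 1; stride < warpSize; stride *= 2):
    // each lane reads lane and lane + stride and keeps the larger value/index pair
    for(int stride = 1; stride < warp_size; stride *= 2)
    {
        for(int lane = 0; lane + stride < warp_size; ++lane)
        {
            if(data[lane + stride] > data[lane])
            {
                data[lane] = data[lane + stride];
                idx[lane]  = idx[lane + stride];
            }
        }
    }

    // lane 0 now holds the warp-wide result, like myDataBuffer[0]/myIndicesBuffer[0]
    std::cout << "max = " << data[0] << " at lane " << idx[0] << "\n"; // max = 49 at lane 27
    return 0;
}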
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @ 5ce317cb
...
@@ -732,9 +732,6 @@ struct ThreadwiseTensorSliceTransfer_v3
         : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
           dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
     {
-        // TODO: fix this
-        static_assert(is_same<SrcData, DstData>::value,
-                      "wrong! current implementation assume SrcData and DstData are same type");
     }

     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
@@ -1004,7 +1001,8 @@ struct ThreadwiseTensorSliceTransfer_v3
             constexpr index_t buffer_offset =
                 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);

-            dst_tmp_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
+            dst_tmp_vector.template AsType<DstData>()(i) =
+                type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
         });

         using dst_vector_t = typename decltype(dst_tmp_vector)::type;
...
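This hunk drops the SrcData == DstData restriction and instead converts each element while it is copied out of the intermediate buffer. The sketch below is an illustrative host-side analogue only: the "convert" functor is a stand-in for ck::type_convert, not the library implementation.

#include <cstdint>
#include <iostream>
#include <vector>

template <typename Dst>
struct convert // stand-in for type_convert<DstData>
{
    template <typename Src>
    Dst operator()(Src x) const { return static_cast<Dst>(x); }
};

int main()
{
    std::vector<float> src = {1.25f, -3.5f, 7.0f};
    std::vector<std::int32_t> dst(src.size());

    // analogous to: dst_tmp_vector(i) = type_convert<DstData>{}(buffer_[Number<buffer_offset>{}])
    for(std::size_t i = 0; i < src.size(); ++i)
        dst[i] = convert<std::int32_t>{}(src[i]);

    for(auto v : dst)
        std::cout << v << " "; // prints: 1 -3 7
    std::cout << "\n";
    return 0;
}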
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @ 5ce317cb
...
@@ -44,15 +44,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -71,15 +66,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -98,15 +88,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -125,15 +110,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -153,15 +133,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
     static constexpr index_t k_per_blk = 1;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -180,15 +155,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -207,15 +177,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -234,15 +199,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = true;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -261,15 +221,10 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -288,15 +243,10 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     static constexpr index_t k_per_blk = 4;
     static constexpr bool is_k_reduction = false;

-    template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset,
-              class FloatA, class FloatB, class FloatC>
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
+        intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
...
@@ -732,7 +682,7 @@ struct XdlopsGemm
         return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
     }

-    template <index_t c_offset, class FloatA, class FloatB, class FloatC>
+    template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
@@ -740,8 +690,7 @@ struct XdlopsGemm
                           "base base_type must be float, half, ushort!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
-            mfma_instr.template run<MPerXdlops, NPerXdlops, c_offset>(
-                p_a_wave[k], p_b_wave[k], p_c_thread);
+            mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
         });
     }
...
@@ -820,7 +769,8 @@ struct XdlopsGemm
     static constexpr auto mfma_instr = mfma.selected_mfma;

     static constexpr auto KPerXdlops  = mfma.GetKPerXdlops();
-    static constexpr auto KPerThread  = mfma.GetKPerThread();
+    static constexpr auto K1PerXdlops = mfma.GetKPerThread();
+    static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

     __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
     {
...
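Two things change in this file: the compile-time accumulator offset (COffset / c_offset) disappears from run/Run, and the per-thread K extent is renamed to K1PerXdlops with K0PerXdlops derived from it. The snippet below only illustrates the K0 x K1 split arithmetic; the concrete values are assumptions chosen for the example, not taken from any particular mfma_type.

#include <iostream>

int main()
{
    constexpr int KPerXdlops  = 8; // K covered by one xdlops instruction (assumed value)
    constexpr int K1PerXdlops = 4; // K handled per thread per issue (assumed value)
    constexpr int K0PerXdlops = KPerXdlops / K1PerXdlops; // same relation as in the hunk above

    static_assert(K0PerXdlops * K1PerXdlops == KPerXdlops, "K must factor exactly into K0 x K1");

    std::cout << "K0PerXdlops = " << K0PerXdlops << "\n"; // prints 2
    return 0;
}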
View file @
5ce317cb
...
@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
...
@@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
extern
"C"
__device__
float4_t
llvm_intrin_amdgcn_mfma_f32_4x4x2bf16
(
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
ushort2_t
,
ushort2_t
,
float4_t
,
int
,
int
,
int
)
__asm
(
"llvm.amdgcn.mfma.f32.4x4x2bf16"
);
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x1f32
;
struct
intrin_mfma_f32_32x32x1f32
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
64
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_a
,
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
1
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_b
,
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
1
>
{}],
1
,
1
,
0
);
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
1
,
0
);
}
}
};
};
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_32x32x1f32
<
32
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x1f32
<
32
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x2f32
;
struct
intrin_mfma_f32_32x32x2f32
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
,
COffset
>
struct
intrin_mfma_f32_32x32x2f32
<
32
,
32
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
llvm_intrin_amdgcn_mfma_f32_32x32x2f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_16x16x4f32
;
struct
intrin_mfma_f32_16x16x4f32
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_16x16x4f32
<
16
,
16
,
COffset
>
struct
intrin_mfma_f32_16x16x4f32
<
16
,
16
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
llvm_intrin_amdgcn_mfma_f32_16x16x4f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_16x16x1f32
;
struct
intrin_mfma_f32_16x16x1f32
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_16x16x1f32
<
16
,
64
,
COffset
>
struct
intrin_mfma_f32_16x16x1f32
<
16
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
llvm_intrin_amdgcn_mfma_f32_16x16x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
2
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
2
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_4x4x1f32
;
struct
intrin_mfma_f32_4x4x1f32
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
,
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
4
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
}
}
};
};
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
,
COffset
>
struct
intrin_mfma_f32_4x4x1f32
<
8
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
float
&
reg_a
,
const
float
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_a
,
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
1
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_b
,
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
1
>
{}],
4
,
1
,
0
);
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x1f32
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
1
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x4f16
;
struct
intrin_mfma_f32_32x32x4f16
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x4f16
<
64
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_a
,
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
1
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_b
,
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
1
>
{}],
1
,
1
,
0
);
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
1
,
0
);
}
}
};
};
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_32x32x4f16
<
32
,
64
,
COffset
>
struct
intrin_mfma_f32_32x32x4f16
<
32
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float32_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
llvm_intrin_amdgcn_mfma_f32_32x32x4f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float32_t
>()[
Number
<
0
>
{}],
1
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x8f16
;
struct
intrin_mfma_f32_32x32x8f16
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_32x32x8f16
<
32
,
32
,
COffset
>
struct
intrin_mfma_f32_32x32x8f16
<
32
,
32
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_32x32x8f16
(
llvm_intrin_amdgcn_mfma_f32_32x32x8f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_16x16x16f16
;
struct
intrin_mfma_f32_16x16x16f16
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_16x16x16f16
<
16
,
16
,
COffset
>
struct
intrin_mfma_f32_16x16x16f16
<
16
,
16
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x16f16
(
llvm_intrin_amdgcn_mfma_f32_16x16x16f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_16x16x4f16
;
struct
intrin_mfma_f32_16x16x4f16
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_16x16x4f16
<
16
,
64
,
COffset
>
struct
intrin_mfma_f32_16x16x4f16
<
16
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
llvm_intrin_amdgcn_mfma_f32_16x16x4f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
2
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
2
,
0
,
0
);
}
}
};
};
template
<
index_t
MPerWave
,
index_t
NPerWave
,
index_t
COffset
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_4x4x4f16
;
struct
intrin_mfma_f32_4x4x4f16
;
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_4x4x4f16
<
4
,
64
,
COffset
>
struct
intrin_mfma_f32_4x4x4f16
<
4
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
}
}
};
};
template
<
index_t
COffset
>
template
<
>
struct
intrin_mfma_f32_4x4x4f16
<
8
,
64
,
COffset
>
struct
intrin_mfma_f32_4x4x4f16
<
8
,
64
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
FloatC
&
reg_c
)
{
{
reg_c
(
Number
<
COffset
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_a
,
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
1
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_b
,
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
1
>
{}],
4
,
1
,
0
);
reg_c
[
Number
<
COffset
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
0
,
0
);
reg_c
(
Number
<
COffset
+
1
>
{}).
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
llvm_intrin_amdgcn_mfma_f32_4x4x4f16
(
reg_a
,
reg_b
,
reg_c
[
Number
<
COffset
+
1
>
{}].
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
4
,
1
,
0
);
}
}
};
};
...
@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
...
@@ -448,7 +340,6 @@ template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
c_vec16_1_t::VecType reg_c);
template <>
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
const ushort2_t* reg_b,
...
...
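Across all these wrappers the pattern is the same: the accumulator offset is no longer a template parameter of the intrinsic wrapper; each wrapper now writes to fixed positions of whatever accumulator view it is handed, and picking where those results live is left to the calling code. The sketch below is illustrative only (plain functions stand in for the wrappers, and the names and sizes are made up); it is not the library API.

#include <array>
#include <iostream>

// "after" style: the callee knows nothing about where its accumulators sit in the register file
void accumulate_block(float* acc, float a, float b)
{
    acc[0] += a * b; // stands in for the first mfma result slot (Number<0>)
    acc[1] += a * b; // stands in for the second one (Number<1>)
}

int main()
{
    std::array<float, 8> reg_c{}; // whole per-thread accumulator file

    // the caller chooses the offset (what COffset used to encode) and passes a sub-view
    constexpr int c_offset = 4;
    accumulate_block(reg_c.data() + c_offset, 2.0f, 3.0f);

    std::cout << reg_c[4] << " " << reg_c[5] << "\n"; // prints: 6 6
    return 0;
}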
composable_kernel/include/utility/config.hpp
View file @ 5ce317cb
...
@@ -90,8 +90,8 @@
 #endif

 // pass tensor descriptor by value or void*
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0

 // merge transformation use magic number division
 #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
...
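The toggle above flips the default from handing tensor descriptors to kernels through an opaque void* to passing them by value as kernel arguments. A minimal host-side sketch of the two calling conventions, with a made-up Descriptor struct standing in for the CK tensor descriptor:

#include <iostream>

struct Descriptor // illustrative stand-in, not the CK descriptor type
{
    int lengths[2];
    int strides[2];
};

// corresponds to the BY_VALUE = 1 path: the descriptor travels as a plain argument
void kernel_by_value(Descriptor desc) { std::cout << desc.lengths[0] * desc.lengths[1] << "\n"; }

// corresponds to the BY_VOID_POINTER = 1 path: an opaque pointer is cast back inside
void kernel_by_void_ptr(const void* p)
{
    const auto* desc = static_cast<const Descriptor*>(p);
    std::cout << desc->lengths[0] * desc->lengths[1] << "\n";
}

int main()
{
    Descriptor d{{4, 8}, {8, 1}};
    kernel_by_value(d);     // prints 32
    kernel_by_void_ptr(&d); // prints 32
    return 0;
}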
composable_kernel/include/utility/data_type.hpp
View file @ 5ce317cb
...
@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion
 };

 template <typename T>
-struct NumericLimits;
+struct NumericLimits
+{
+    __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
+
+    __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
+
+    __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
+};

 template <>
-struct NumericLimits<int32_t>
+struct NumericLimits<half_t>
 {
-    __host__ __device__ static constexpr int32_t Min()
-    {
-        return std::numeric_limits<int32_t>::min();
-    }
-    __host__ __device__ static constexpr int32_t Max()
-    {
-        return std::numeric_limits<int32_t>::max();
-    }
+    static constexpr unsigned short binary_min    = 0x0400;
+    static constexpr unsigned short binary_max    = 0x7BFF;
+    static constexpr unsigned short binary_lowest = 0xFBFF;
+
+    __host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
+
+    __host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
+
+    __host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
 };

 } // namespace ck
...
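The half_t specialization stores the limits as raw IEEE binary16 bit patterns: 0x0400 is the smallest positive normal value (2^-14), 0x7BFF is the largest finite value (65504), and 0xFBFF is the same magnitude with the sign bit set. A quick host-side check (illustrative, written against plain C++ rather than the library's half_t):

#include <cmath>
#include <cstdint>
#include <iostream>

double decode_half(std::uint16_t bits)
{
    const int sign     = (bits >> 15) & 0x1;
    const int exponent = (bits >> 10) & 0x1F;
    const int mantissa = bits & 0x3FF;

    // normal numbers only, which is all that these three constants are
    double value = std::ldexp(1.0 + mantissa / 1024.0, exponent - 15);
    return sign ? -value : value;
}

int main()
{
    std::cout << decode_half(0x0400) << "\n"; // 6.10352e-05, i.e. 2^-14
    std::cout << decode_half(0x7BFF) << "\n"; // 65504
    std::cout << decode_half(0xFBFF) << "\n"; // -65504
    return 0;
}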
composable_kernel/include/utility/dynamic_buffer.hpp
View file @ 5ce317cb
...
@@ -38,6 +38,10 @@ struct DynamicBuffer
         return BufferAddressSpace;
     }

+    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
+
+    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
+
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                          typename scalar_type<remove_cvref_t<T>>::type>::value,
...
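The two new operators give DynamicBuffer the same read/write split used elsewhere in the library: operator[] for const element reads and operator() for mutable access. A trimmed-down host-side stand-in (illustrative only, not the real DynamicBuffer):

#include <iostream>

template <typename T>
struct SimpleBuffer
{
    T* p_data_;

    constexpr const T& operator[](int i) const { return p_data_[i]; } // read access
    constexpr T& operator()(int i) { return p_data_[i]; }             // write access
};

int main()
{
    float storage[4] = {0.f, 1.f, 2.f, 3.f};
    SimpleBuffer<float> buf{storage};

    buf(2) = 42.f;               // mutate through operator()
    std::cout << buf[2] << "\n"; // read through operator[], prints 42
    return 0;
}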
composable_kernel/include/utility/reduction_common.hpp
0 → 100644
View file @ 5ce317cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_COMMON_HPP
#define CK_REDUCTION_COMMON_HPP
#include "reduction_enums.hpp"
namespace ck {

struct float_equal_one
{
    template <class T>
    __device__ inline bool operator()(T x)
    {
        return x <= static_cast<T>(1.0f) and x >= static_cast<T>(1.0f);
    };
};

struct float_equal_zero
{
    template <class T>
    __device__ inline bool operator()(T x)
    {
        return x <= static_cast<T>(0.0f) and x >= static_cast<T>(0.0f);
    };
};

}; // end of namespace ck

#endif
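Writing x <= c and x >= c expresses exact equality without a direct floating-point ==. The usage sketch below is host-side and illustrative: the structs are reproduced (without __device__) so the snippet compiles on its own, and the alpha/beta scenario is only the kind of check such predicates are typically used for, not code from this commit.

#include <iostream>

struct float_equal_one
{
    template <class T>
    bool operator()(T x) { return x <= static_cast<T>(1.0f) and x >= static_cast<T>(1.0f); }
};

struct float_equal_zero
{
    template <class T>
    bool operator()(T x) { return x <= static_cast<T>(0.0f) and x >= static_cast<T>(0.0f); }
};

int main()
{
    float alpha = 1.0f, beta = 0.5f;

    if(!float_equal_one{}(alpha))
        std::cout << "would scale the reduced values by alpha\n"; // skipped for alpha = 1
    if(!float_equal_zero{}(beta))
        std::cout << "would blend in beta * prior output\n"; // taken for beta = 0.5
    return 0;
}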
composable_kernel/include/utility/reduction_enums.hpp
0 → 100644
View file @ 5ce317cb
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_ENUMS_HPP
#define CK_REDUCTION_ENUMS_HPP
namespace ck {

enum class ReduceTensorOp_t
{
    ADD   = 0,
    MUL   = 1,
    MIN   = 2,
    MAX   = 3,
    AMAX  = 4,
    AVG   = 5,
    NORM1 = 6,
    NORM2 = 7,
    // MUL_NO_ZEROS = 8,
};

enum class NanPropagation_t
{
    NOT_PROPAGATE_NAN = 0,
    PROPAGATE_NAN     = 1,
};

enum class ReduceTensorIndices_t
{
    NO_INDICES        = 0,
    FLATTENED_INDICES = 1,
};

enum class IndicesType_t
{
    INDICES_32BIT = 0,
    INDICES_64BIT = 1,
    INDICES_16BIT = 2,
    INDICES_8BIT  = 3,
};

}; // end of namespace ck

#endif
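These enums closely mirror the cuDNN/MIOpen reduction-descriptor style (operation, NaN propagation, whether indices are produced, and the index width). The following host-side sketch is illustrative only: it shows how a dispatcher might map a few ReduceTensorOp_t values onto plain accumulation loops, with the enum re-declared in trimmed form so the snippet compiles on its own.

#include <algorithm>
#include <iostream>
#include <vector>

enum class ReduceTensorOp_t { ADD = 0, MUL = 1, MIN = 2, MAX = 3 }; // trimmed copy for the example

float reduce(const std::vector<float>& v, ReduceTensorOp_t op)
{
    switch(op)
    {
    case ReduceTensorOp_t::ADD: {
        float acc = 0.f;
        for(float x : v) acc += x;
        return acc;
    }
    case ReduceTensorOp_t::MUL: {
        float acc = 1.f;
        for(float x : v) acc *= x;
        return acc;
    }
    case ReduceTensorOp_t::MIN: return *std::min_element(v.begin(), v.end());
    case ReduceTensorOp_t::MAX: return *std::max_element(v.begin(), v.end());
    }
    return 0.f;
}

int main()
{
    std::vector<float> v = {2.f, -1.f, 4.f};
    std::cout << reduce(v, ReduceTensorOp_t::ADD) << "\n"; // 5
    std::cout << reduce(v, ReduceTensorOp_t::MAX) << "\n"; // 4
    return 0;
}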