Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
29a118c6
Commit
29a118c6
authored
Sep 05, 2021
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into merge_use_division_mod
parents
1a43a538
19613902
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
642 additions
and
264 deletions
+642
-264
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...rd_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+129
-0
composable_kernel/include/tensor_description/tensor_adaptor.hpp
...able_kernel/include/tensor_description/tensor_adaptor.hpp
+1
-2
composable_kernel/include/tensor_description/tensor_descriptor.hpp
...e_kernel/include/tensor_description/tensor_descriptor.hpp
+3
-4
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+5
-7
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
...include/tensor_operation/threadwise_contraction_dlops.hpp
+18
-24
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
...nel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
+9
-12
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
.../include/tensor_operation/threadwise_tensor_slice_set.hpp
+2
-2
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+25
-34
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
.../tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
+16
-19
composable_kernel/include/utility/amd_buffer_addressing.hpp
composable_kernel/include/utility/amd_buffer_addressing.hpp
+113
-70
composable_kernel/include/utility/array.hpp
composable_kernel/include/utility/array.hpp
+1
-1
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+11
-0
composable_kernel/include/utility/dynamic_buffer.hpp
composable_kernel/include/utility/dynamic_buffer.hpp
+56
-66
composable_kernel/include/utility/tuple.hpp
composable_kernel/include/utility/tuple.hpp
+1
-1
composable_kernel/include/utility/tuple_helper.hpp
composable_kernel/include/utility/tuple_helper.hpp
+1
-3
composable_kernel/include/utility/type.hpp
composable_kernel/include/utility/type.hpp
+3
-0
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
...ution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
+2
-7
host/driver_offline/CMakeLists.txt
host/driver_offline/CMakeLists.txt
+6
-0
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
...ackward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
+12
-12
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...ard_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+228
-0
No files found.
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
29a118c6
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
(
const
TensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_c_hi_wi_grid_desc
,
const
TensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_grid_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
GemmK1Value
>
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
const
auto
N
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I1
);
const
auto
K
=
out_n_k_ho_wo_grid_desc
.
GetLength
(
I1
);
const
auto
Hi
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I2
);
const
auto
Wi
=
in_n_c_hi_wi_grid_desc
.
GetLength
(
I3
);
const
auto
Ho
=
out_n_k_ho_wo_grid_desc
.
GetLength
(
I2
);
const
auto
Wo
=
out_n_k_ho_wo_grid_desc
.
GetLength
(
I3
);
const
auto
Y
=
wei_k_c_y_x_grid_desc
.
GetLength
(
I2
);
const
auto
X
=
wei_k_c_y_x_grid_desc
.
GetLength
(
I3
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
K
;
const
auto
GemmN
=
C
*
Y
*
X
;
const
auto
GemmK
=
N
*
Ho
*
Wo
;
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
*
Y
*
X
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// input tensor
const
auto
in_n_c_hip_wip_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_c_y_ho_x_wo_grid_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
C
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
N
,
Ho
*
Wo
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
}
}
// namespace ck
#endif
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
29a118c6
...
...
@@ -189,8 +189,7 @@ struct TensorAdaptor
bool
is_known
=
true
;
static_for
<
0
,
Transforms
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
is_known
&=
remove_cv_t
<
remove_reference_t
<
decltype
(
Transforms
{}[
i
])
>>::
IsKnownAtCompileTime
();
is_known
&=
remove_cvref_t
<
decltype
(
Transforms
{}[
i
])
>::
IsKnownAtCompileTime
();
});
return
is_known
&&
is_known_at_compile_time
<
ElementSize
>::
value
;
...
...
composable_kernel/include/tensor_description/tensor_descriptor.hpp
View file @
29a118c6
...
...
@@ -185,8 +185,7 @@ struct TensorDescriptor
bool
is_known
=
true
;
static_for
<
0
,
Transforms
::
Size
(),
1
>
{}([
&
](
auto
i
)
{
is_known
&=
remove_cv_t
<
remove_reference_t
<
decltype
(
Transforms
{}[
i
])
>>::
IsKnownAtCompileTime
();
is_known
&=
remove_cvref_t
<
decltype
(
Transforms
{}[
i
])
>::
IsKnownAtCompileTime
();
});
return
is_known
&&
is_known_at_compile_time
<
ElementSize
>::
value
&&
...
...
@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
template
<
typename
TensorDesc
>
using
TensorCoordinate_t
=
decltype
(
make_tensor_coordinate
(
TensorDesc
{},
MultiIndex
<
remove_cv
_t
<
remove_reference
_t
<
TensorDesc
>
>
::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv
ref
_t
<
TensorDesc
>::
GetNumOfDimension
()
>
{}));
template
<
typename
TensorDesc
>
using
TensorCoordinateStep_t
=
decltype
(
make_tensor_coordinate_step
(
TensorDesc
{},
MultiIndex
<
remove_cv
_t
<
remove_reference
_t
<
TensorDesc
>
>
::
GetNumOfDimension
()
>
{}));
TensorDesc
{},
MultiIndex
<
remove_cv
ref
_t
<
TensorDesc
>::
GetNumOfDimension
()
>
{}));
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
29a118c6
...
...
@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
const
BThreadBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABlockBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp
View file @
29a118c6
...
...
@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
View file @
29a118c6
...
...
@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
CDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
AOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
BOriginIdx
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
COriginIdx
>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
is_same
<
remove_cvref_t
<
typename
ABuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp
View file @
29a118c6
...
...
@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
static_assert
(
Buffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv
_t
<
remove_reference
_t
<
OriginIdx
>>
>
::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv
ref
_t
<
OriginIdx
>>::
value
,
"wrong! OriginIdx need to be known at compile-time"
);
// Desc is known at compile-time
constexpr
auto
desc
=
remove_cv
_t
<
remove_reference
_t
<
Desc
>
>
{};
constexpr
auto
desc
=
remove_cv
ref
_t
<
Desc
>
{};
// OriginIdx is known at compile-time
constexpr
auto
origin_idx
=
to_multi_index
(
OriginIdx
{});
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
29a118c6
...
...
@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcSliceOriginIdx
>>>::
value
,
"wrong! SrcSliceOrigin need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcSliceOriginIdx
>>::
value
,
"wrong! SrcSliceOrigin need to known at compile-time"
);
static_assert
(
SrcBuffer
::
IsStaticBuffer
(),
"wrong! SrcBuffer need to be StaticBuffer"
);
// static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
// remove_cv_t<remove_reference_t<SrcData>>>::value,
//"wrong! SrcBuffer data type is wrong");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
src_slice_origin_idx
=
to_multi_index
(
SrcSliceOriginIdx
{});
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -421,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! DstDesc need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstSliceOriginIdx
>>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
DstSliceOriginIdx
>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
&&
"wrong! inconsistent type"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
&&
"wrong! inconsistent type"
);
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
constexpr
auto
dst_slice_origin_idx
=
DstSliceOriginIdx
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -742,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
SrcData
>>
>
::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_
cv
ref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -899,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -1315,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcRefToOriginDisplacement
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstOriginIdx
>>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp
View file @
29a118c6
...
...
@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
SrcData
>>
>
::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_
cv
ref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
// tensor descriptor for src_vector
constexpr
auto
src_vector_tensor_lengths
=
SrcVectorTensorLengths
{};
...
...
@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum_t
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_ref
erence
_t
<
DstData
>>
>
::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_
cv
ref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
// tensor descriptor for dst_vector
constexpr
auto
dst_vector_tensor_lengths
=
DstVectorTensorLengths
{};
...
...
@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcRefToOriginDisplacement
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstOriginIdx
>>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cv
_t
<
remove_reference
_t
<
SrcDesc
>
>
{};
constexpr
auto
dst_desc
=
remove_cv
_t
<
remove_reference
_t
<
DstDesc
>
>
{};
constexpr
auto
src_desc
=
remove_cv
ref
_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cv
ref
_t
<
DstDesc
>
{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
...
...
composable_kernel/include/utility/amd_buffer_addressing.hpp
View file @
29a118c6
...
...
@@ -225,13 +225,49 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_
t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
floa
t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
)),
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 load to mimic fp64 load
if
constexpr
(
N
==
1
)
{
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
as_type
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
as_type
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
0
);
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
as_type
<
double2_t
>
(
f32_0
);
tmp
.
AsType
<
double2_t
>
()(
Number
<
1
>
{})
=
as_type
<
double2_t
>
(
f32_1
);
return
tmp
.
AsType
<
double4_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
...
...
@@ -283,25 +319,11 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
}
else
if
constexpr
(
N
==
8
)
{
#if 0
vector_type<half_t, 8> tmp;
tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<half4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);
return tmp.AsType<half8_t>()(Number<0>{});
#else
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
return
as_type
<
half8_t
>
(
tmp
);
#endif
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
...
...
@@ -433,13 +455,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
)),
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 store to mimic fp64 store
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
as_type
<
float2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
as_type
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
...
...
@@ -466,6 +509,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
half_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
half_t
),
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
...
...
@@ -552,49 +638,6 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
half_t
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_store_fp16x4
(
tmp
.
AsType
<
half4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
half_t
),
0
);
}
}
}
template
<
typename
T
,
index_t
N
>
...
...
@@ -720,7 +763,7 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
}
// buffer_load requires:
// 1) p_src_wave must
be in
global memory space
// 1) p_src_wave must
point to
global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
...
...
@@ -754,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
}
// buffer_load requires:
// 1) p_src_wave must
be in
global memory space
// 1) p_src_wave must
point to
global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
...
...
@@ -782,7 +825,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
}
// buffer_store requires:
// 1) p_dst_wave must
be
global memory
// 1) p_dst_wave must
point to
global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
...
...
@@ -816,7 +859,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
// buffer_atomic_add requires:
// 1) p_dst_wave must
be
global memory
// 1) p_dst_wave must
point to
global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
...
...
composable_kernel/include/utility/array.hpp
View file @
29a118c6
...
...
@@ -48,7 +48,7 @@ struct Array<TData, 0>
template
<
typename
X
,
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
{
using
data_type
=
remove_cv
_t
<
remove_reference
_t
<
X
>
>
;
using
data_type
=
remove_cv
ref
_t
<
X
>
;
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{{
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Xs
>
(
xs
)...}};
}
...
...
composable_kernel/include/utility/data_type.hpp
View file @
29a118c6
...
...
@@ -73,6 +73,13 @@ struct scalar_type<vector_type<T, N>>
};
//
template
<
>
struct
scalar_type
<
double
>
{
using
type
=
double
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
float
>
{
...
...
@@ -864,6 +871,10 @@ struct vector_type<T, 256>
}
};
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
...
...
composable_kernel/include/utility/dynamic_buffer.hpp
View file @
29a118c6
...
...
@@ -39,18 +39,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -67,15 +64,14 @@ struct DynamicBuffer
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
return
amd_buffer_load_invalid_element_return_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
t_per_x
>
(
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
}
}
...
...
@@ -94,18 +90,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -115,7 +108,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
t_per_x
>
(
amd_buffer_store
<
remove_cv
ref
_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
#else
if
(
is_valid_element
)
...
...
@@ -136,70 +129,65 @@ struct DynamicBuffer
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
if
constexpr
(
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
,
int8_t
>::
value
)
if
constexpr
(
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
)
{
static_assert
(
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x2_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x4_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x8_t
>::
value
)
||
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x16_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x16_t
>::
value
),
"wrong! not implemented for this combination, please add "
"implementation"
);
static_assert
((
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x2_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x4_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
||
(
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
),
"wrong! not implemented for this combination, please add "
"implementation"
);
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8_t
>::
value
)
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int8_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int8_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8x2_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8x2_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int16_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int16_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
_t
<
remove_reference
_t
<
X
>
>
,
int8x4_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cv
ref
_t
<
T
>
,
int8_t
>::
value
&&
is_same
<
remove_cv
ref
_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x4_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x4_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x4_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x8_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x8_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32x2_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32x2_t
*>
(
&
x
);
}
else
if
constexpr
(
is_same
<
remove_cv_t
<
remove_reference_t
<
T
>>
,
int8x16_t
>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
X
>>
,
int8x16_t
>::
value
)
else
if
constexpr
(
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
...
...
@@ -224,18 +212,15 @@ struct DynamicBuffer
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>
>>::
type
,
typename
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
type
>::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
AtomicAdd
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
T
>>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cv_t
<
remove_reference_t
<
X
>>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X need to be multiple T"
);
...
...
@@ -245,7 +230,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cv
_t
<
remove_reference
_t
<
T
>
>
,
t_per_x
>
(
amd_buffer_atomic_add
<
remove_cv
ref
_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
#else
if
(
is_valid_element
)
...
...
@@ -266,9 +251,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
true
>
{
p
,
element_space_size
};
}
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
>
template
<
AddressSpaceEnum_t
BufferAddressSpace
,
typename
T
,
typename
ElementSpaceSize
,
typename
X
,
typename
enable_if
<
is_same
<
remove_cvref_t
<
T
>,
remove_cvref_t
<
X
>>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
,
T
invalid_element_value
)
make_dynamic_buffer
(
T
*
p
,
ElementSpaceSize
element_space_size
,
X
invalid_element_value
)
{
return
DynamicBuffer
<
BufferAddressSpace
,
T
,
ElementSpaceSize
,
false
>
{
p
,
element_space_size
,
invalid_element_value
};
...
...
composable_kernel/include/utility/tuple.hpp
View file @
29a118c6
...
...
@@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
return
Tuple
<
remove_cv
_t
<
remove_reference
_t
<
Xs
>
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cv
ref
_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
}
}
// namespace ck
...
...
composable_kernel/include/utility/tuple_helper.hpp
View file @
29a118c6
...
...
@@ -14,9 +14,7 @@ struct is_known_at_compile_time<Tuple<Ts...>>
return
container_reduce
(
Tuple
<
Ts
...
>
{},
[](
auto
x
,
bool
r
)
{
return
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
decltype
(
x
)
>>>::
value
&
r
;
return
is_known_at_compile_time
<
remove_cvref_t
<
decltype
(
x
)
>>::
value
&
r
;
},
true
);
}
...
...
composable_kernel/include/utility/type.hpp
View file @
29a118c6
...
...
@@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference<T>::type;
template
<
typename
T
>
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
...
...
composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp
View file @
29a118c6
...
...
@@ -374,13 +374,8 @@ extern "C" __global__ void
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1
{},
CGridBlockCluster_BlockId_To_GM10_GN10
{}));
const
auto
desc_tuple
=
*
reinterpret_cast
<
const
DescTuple
*>
(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// TODO: how to cast?
(
const
void
*
)
p_desc_tuple
#pragma clang diagnostic pop
);
const
auto
desc_tuple
=
*
reinterpret_cast
<
const
DescTuple
*>
(
cast_pointer_to_generic_address_space
(
p_desc_tuple
));
const
auto
a_grid_desc_gk0_gm0_gm10_gm11_gk1
=
desc_tuple
[
I0
];
const
auto
b_grid_desc_gk0_gn0_gn10_gn11_gk1
=
desc_tuple
[
I1
];
...
...
host/driver_offline/CMakeLists.txt
View file @
29a118c6
...
...
@@ -13,9 +13,15 @@ include_directories(BEFORE
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp
)
set
(
CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp
)
set
(
CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp
)
set
(
GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp
)
add_executable
(
conv_fwd_driver_offline
${
CONV_FWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_bwd_driver_offline
${
CONV_BWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_wrw_driver_offline
${
CONV_WRW_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
gemm_driver_offline
${
GEMM_DRIVER_OFFLINE_SOURCE
}
)
target_link_libraries
(
conv_fwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_bwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_wrw_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
gemm_driver_offline PRIVATE host_tensor
)
host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
View file @
29a118c6
...
...
@@ -208,20 +208,20 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
wei_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+:
g
emm
k
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+:
g
emm
m
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+:
g
emm
k
1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-: Gemm
k
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemm
m
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemm
k
1
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+:
G
emm
K
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+:
G
emm
M
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+:
G
emm
K
1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-: Gemm
K
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: Gemm
M
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: Gemm
K
1
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+:
g
emm
k
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+:
g
emm
n
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+:
g
emm
k
1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-:
g
emm
k
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-:
g
emm
n
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-:
g
emm
k
1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+:
G
emm
K
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
// 1+:
G
emm
N
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+:
G
emm
K
1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-:
G
emm
K
0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
// 1-:
G
emm
N
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-:
G
emm
K
1
// clang-format off
constexpr
auto
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
...
...
host/driver_offline/include/device_convolution_
for
ward_implicit_gemm_v4r4r
3
_xdlops_nhw
c
_kyx
c
_nhw
k
.hpp
→
host/driver_offline/include/device_convolution_
back
ward_
weight_
implicit_gemm_v4r4r
2
_xdlops_n
c
hw_k
c
yx_n
k
hw.hpp
View file @
29a118c6
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_
forward
_convolution_into_gemm_v4r4r2_nhw
c
_kyx
c
_nhw
k
.hpp"
#include "transform_
backward_weight
_convolution_into_gemm_v4r4r2_n
c
hw_k
c
yx_n
k
hw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
...
...
@@ -14,17 +14,17 @@ template <typename TInWei,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
void
device_convolution_
for
ward_implicit_gemm_v4r4r
3
_xdlops_nhw
c
_kyx
c
_nhw
k
(
const
InLengths
&
in_n_hi_wi_
c_
lengths
,
const
WeiLengths
&
wei_k_y_x_
c_
lengths
,
const
OutLengths
&
out_n_ho_wo_
k_
lengths
,
void
device_convolution_
back
ward_
weight_
implicit_gemm_v4r4r
2
_xdlops_n
c
hw_k
c
yx_n
k
hw
(
const
InLengths
&
in_n_
c_
hi_wi_lengths
,
const
WeiLengths
&
wei_k_
c_
y_x_lengths
,
const
OutLengths
&
out_n_
k_
ho_wo_lengths
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
Tensor
<
TInWei
>&
in_n_hi_wi
_c
,
const
Tensor
<
TInWei
>&
wei_k_y_x
_c
,
Tensor
<
TOut
>&
out_n_ho_wo
_k
,
const
Tensor
<
TInWei
>&
in_n_
c_
hi_wi
,
Tensor
<
TInWei
>&
wei_k_
c_
y_x
,
const
Tensor
<
TOut
>&
out_n_
k_
ho_wo
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
...
...
@@ -34,55 +34,21 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_hi_wi_
c_
device_buf
(
sizeof
(
TInWei
)
*
in_n_hi_wi
_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_
c_
device_buf
(
sizeof
(
TInWei
)
*
wei_k_y_x
_c
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_ho_wo_
k_
device_buf
(
sizeof
(
TOut
)
*
out_n_ho_wo
_k
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_
c_
hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_
c_
hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_
c_
y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_
c_
y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_
k_
ho_wo_device_buf
(
sizeof
(
TOut
)
*
out_n_
k_
ho_wo
.
mDesc
.
GetElementSpace
());
in_n_hi_wi_
c_
device_buf
.
ToDevice
(
in_n_hi_wi
_c
.
mData
.
data
());
wei_k_y_x_
c_
device_buf
.
ToDevice
(
wei_k_y_x
_c
.
mData
.
data
());
out_n_ho_wo_
k_
device_buf
.
ToDevice
(
out_n_ho_wo
_k
.
mData
.
data
());
in_n_
c_
hi_wi_device_buf
.
ToDevice
(
in_n_
c_
hi_wi
.
mData
.
data
());
wei_k_
c_
y_x_device_buf
.
ToDevice
(
wei_k_
c_
y_x
.
mData
.
data
());
out_n_
k_
ho_wo_device_buf
.
ToDevice
(
out_n_
k_
ho_wo
.
mData
.
data
());
const
auto
in_n_hi_wi_
c_
desc
=
make_naive_tensor_descriptor_packed
(
in_n_hi_wi_
c_
lengths
);
const
auto
wei_k_y_x_
c_
desc
=
make_naive_tensor_descriptor_packed
(
wei_k_y_x_
c_
lengths
);
const
auto
out_n_ho_wo_
k_
desc
=
make_naive_tensor_descriptor_packed
(
out_n_ho_wo_
k_
lengths
);
const
auto
in_n_
c_
hi_wi_desc
=
make_naive_tensor_descriptor_packed
(
in_n_
c_
hi_wi_lengths
);
const
auto
wei_k_
c_
y_x_desc
=
make_naive_tensor_descriptor_packed
(
wei_k_
c_
y_x_lengths
);
const
auto
out_n_
k_
ho_wo_desc
=
make_naive_tensor_descriptor_packed
(
out_n_
k_
ho_wo_lengths
);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmK1
=
4
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
4
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
4
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
4
;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
...
...
@@ -91,54 +57,26 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1       = 4;
    constexpr index_t GemmK1       = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [256, 256, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave  = 32;
    constexpr index_t GemmNPerWave  = 32;
    constexpr index_t GemmK1        = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 4;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 4, 8>;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
...
...
@@ -154,70 +92,73 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN  = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
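
    // A minimal consistency-check sketch for whichever config branch above is active, assuming
    // the usual xdlops tile decomposition with a wave size of 64:
    // GemmMPerBlock = MRepeat * MWaves * GemmMPerWave, GemmNPerBlock = NRepeat * NWaves * GemmNPerWave,
    // and BlockSize = MWaves * NWaves * 64. MWaves/NWaves are illustrative names introduced
    // only for this check.
    constexpr index_t MWaves = GemmMPerBlock / (MRepeat * GemmMPerWave);
    constexpr index_t NWaves = GemmNPerBlock / (NRepeat * GemmNPerWave);
    static_assert(MWaves * MRepeat * GemmMPerWave == GemmMPerBlock, "GemmM tile mismatch");
    static_assert(NWaves * NRepeat * GemmNPerWave == GemmNPerBlock, "GemmN tile mismatch");
    static_assert(MWaves * NWaves * 64 == BlockSize, "BlockSize != number of waves * 64");
    // The "vector load 4" notes in the configs above also imply that Ho * Wo must be divisible
    // by the chosen GemmK1 vector width; a hedged sketch of such a guard could be:
    // assert(out_n_k_ho_wo_lengths[I2] * out_n_k_ho_wo_lengths[I3] %
    //        GemmABlockTransferSrcScalarPerVector_GemmK1 == 0);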
    const auto descs =
        transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
                                                                          in_n_hi_wi_c_desc,
                                                                          out_n_ho_wo_k_desc,
                                                                          conv_strides,
                                                                          conv_dilations,
                                                                          in_left_pads,
                                                                          in_right_pads,
                                                                          Number<GemmK1>{});

    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto descs =
        transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
            wei_k_c_y_x_desc,
            in_n_c_hi_wi_desc,
            out_n_k_ho_wo_desc,
            conv_strides,
            conv_dilations,
            in_left_pads,
            in_right_pads,
            Number<GemmK1>{});

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
    const auto out_gemmm_gemmn_grid_desc         = descs[I2];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];
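
    // The transform above recasts the backward-weight convolution as an implicit GEMM in which
    // the weight gradient is the GEMM C matrix; judging from the descriptor names, the intended
    // mapping is presumably GemmM = K, GemmN = C * Y * X, GemmK = N * Ho * Wo = GemmK0 * GemmK1.
    // A small sketch that derives those sizes from the raw tensor lengths already in scope
    // (the expected_* names are illustrative only):
    const auto expected_gemm_m = wei_k_c_y_x_lengths[I0];
    const auto expected_gemm_n =
        wei_k_c_y_x_lengths[I1] * wei_k_c_y_x_lengths[I2] * wei_k_c_y_x_lengths[I3];
    const auto expected_gemm_k =
        out_n_k_ho_wo_lengths[I0] * out_n_k_ho_wo_lengths[I2] * out_n_k_ho_wo_lengths[I3];
    std::cout << "expected GemmM/GemmN/GemmK = " << expected_gemm_m << "/" << expected_gemm_n
              << "/" << expected_gemm_k << std::endl;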
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{}));
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM
                              Sequence<0, 0, 1, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 2, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM
                              Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

    constexpr auto out_m0_m1_m2_n_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 1, 0, 0>{}),
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 0, 0, 0>{},
                              Sequence<0, 0, 2, 0, 0>{}));

    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmN
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
...
...
@@ -227,14 +168,15 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
            TAcc,
            TOut,
            InMemoryDataOperationEnum_t::Set,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(out_gemmm_gemmn_grid_desc),
            decltype(wei_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
...
...
@@ -250,53 +192,37 @@ void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmBBlockTransferSrcScalarPerVector_GemmK1,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
            6,
            Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(out_m0_m1_m2_n_grid_step_hacks),
            decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false // CAccessOrderMRepeatNRepeat
            >(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
              static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
              wei_gemmk0_gemmm_gemmk1_grid_desc,
              in_gemmk0_gemmn_gemmk1_grid_desc,
              out_gemmm_gemmn_grid_desc,
              wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_step_hacks,
              out_m0_m1_m2_n_grid_step_hacks,
              wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        {
            const auto N = out_n_ho_wo_k_lengths[I0];
            const auto K = out_n_ho_wo_k_lengths[I3];
            const auto C = wei_k_y_x_c_lengths[I3];

            const auto Hi = in_n_hi_wi_c_lengths[I1];
            const auto Wi = in_n_hi_wi_c_lengths[I2];

            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];

            const auto Y = wei_k_y_x_c_lengths[I1];
            const auto X = wei_k_y_x_c_lengths[I2];

            float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
            false
            >(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
              out_gemmk0_gemmm_gemmk1_grid_desc,
              in_gemmk0_gemmn_gemmk1_grid_desc,
              wei_gemmm_gemmn_grid_desc,
              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_step_hacks,
              wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
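
// A self-contained sketch of the TFlop/s arithmetic used in the timing block above (a
// hypothetical helper, not defined in this header): ave_time is measured in milliseconds,
// so dividing the FLOP count by 1e9 and then by milliseconds yields TFlop/s. Assumes
// std::size_t is available, as it already is elsewhere in this file.
inline float convolution_tflops_sketch(std::size_t N,
                                       std::size_t K,
                                       std::size_t C,
                                       std::size_t Y,
                                       std::size_t X,
                                       std::size_t Ho,
                                       std::size_t Wo,
                                       float ave_time_ms)
{
    // 2 FLOPs (multiply + add) per MAC; N * K * Ho * Wo * C * Y * X MACs for the convolution
    const std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;

    return static_cast<float>(flop) / (std::size_t(1000) * 1000 * 1000) / ave_time_ms;
}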