Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a84f254a
Commit
a84f254a
authored
Dec 06, 2021
by
Chao Liu
Browse files
added conv+bias+relu
parent
8159be33
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
588 additions
and
282 deletions
+588
-282
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp
+1
-19
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp
...ensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp
+8
-72
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+0
-1
device_operation/device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
..._conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
+15
-15
device_operation/include/device_conv2d_fwd_xdl_bias_activation_nhwc_kyxc_nhwk.hpp
.../device_conv2d_fwd_xdl_bias_activation_nhwc_kyxc_nhwk.hpp
+2
-35
device_operation/include/device_conv_fwd_bias_activation.hpp
device_operation/include/device_conv_fwd_bias_activation.hpp
+0
-1
device_operation/include/element_wise_operation.hpp
device_operation/include/element_wise_operation.hpp
+101
-0
example/1_gemm_xdl/gemm_xdl.cpp
example/1_gemm_xdl/gemm_xdl.cpp
+9
-26
example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
+1
-1
example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp
...u_add/include/device_gemm_xdl_two_extra_source_reduce.hpp
+18
-0
example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp
example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp
+4
-22
example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp
...e/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp
+3
-15
example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp
...2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp
+0
-74
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+13
-1
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
+293
-0
profiler/profile_conv_fwd_bias_relu.cpp
profiler/profile_conv_fwd_bias_relu.cpp
+114
-0
profiler/profiler.cpp
profiler/profiler.cpp
+6
-0
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp
View file @
a84f254a
...
@@ -19,7 +19,6 @@ template <typename GridwiseGemm,
...
@@ -19,7 +19,6 @@ template <typename GridwiseGemm,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
...
@@ -34,12 +33,10 @@ __global__ void
...
@@ -34,12 +33,10 @@ __global__ void
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC
*
__restrict__
p_c0_grid
,
const
FloatC
*
__restrict__
p_c0_grid
,
const
FloatC
*
__restrict__
p_c1_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
...
@@ -54,13 +51,11 @@ __global__ void
...
@@ -54,13 +51,11 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_c0_grid
,
p_c0_grid
,
p_c1_grid
,
p_shared_block
,
p_shared_block
,
a_grid_desc_k0_m_k1
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
...
@@ -76,7 +71,6 @@ template <index_t BlockSize,
...
@@ -76,7 +71,6 @@ template <index_t BlockSize,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
...
@@ -326,9 +320,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
...
@@ -326,9 +320,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
using
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C1GridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
template
<
bool
HasMainKBlockLoop
>
...
@@ -337,13 +328,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
...
@@ -337,13 +328,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC
*
__restrict__
p_c0_grid
,
const
FloatC
*
__restrict__
p_c0_grid
,
const
FloatC
*
__restrict__
p_c1_grid
,
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
...
@@ -359,9 +348,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
...
@@ -359,9 +348,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c0_grid
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
p_c0_grid
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c1_grid
,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
// divide block work by [M, N]
// divide block work by [M, N]
...
@@ -615,7 +601,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
...
@@ -615,7 +601,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
CElementwiseOperation
,
CElementwiseOperation
,
Sequence
<
M0
,
N0
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
M0
,
N0
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
...
@@ -626,7 +611,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
...
@@ -626,7 +611,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
true
>
{
true
>
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
m_thread_data_on_grid_idx
[
I1
],
m_thread_data_on_grid_idx
[
I1
],
...
@@ -644,9 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
...
@@ -644,9 +628,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
c_grid_buf
,
c_grid_buf
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_buf
,
c0_grid_buf
);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c1_grid_buf
);
}
}
}
}
};
// namespace ck
};
// namespace ck
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r5.hpp
View file @
a84f254a
...
@@ -31,7 +31,6 @@ template <typename SrcData,
...
@@ -31,7 +31,6 @@ template <typename SrcData,
typename
SrcDesc
,
typename
SrcDesc
,
typename
DstDesc
,
typename
DstDesc
,
typename
Dst0Desc
,
// this is really one of sources, but it has same shape as DstDesc
typename
Dst0Desc
,
// this is really one of sources, but it has same shape as DstDesc
typename
Dst1Desc
,
// this is really one of sources, but it has same shape as DstDesc
typename
DstElementwiseOperation
,
typename
DstElementwiseOperation
,
typename
SliceLengths
,
typename
SliceLengths
,
typename
DimAccessOrder
,
typename
DimAccessOrder
,
...
@@ -49,21 +48,17 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -49,21 +48,17 @@ struct ThreadwiseTensorSliceTransfer_v1r5
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
Dst0Coord
=
decltype
(
make_tensor_coordinate
(
Dst0Desc
{},
Index
{}));
using
Dst0Coord
=
decltype
(
make_tensor_coordinate
(
Dst0Desc
{},
Index
{}));
using
Dst1Coord
=
decltype
(
make_tensor_coordinate
(
Dst1Desc
{},
Index
{}));
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
using
Dst0CoordStep
=
decltype
(
make_tensor_coordinate_step
(
Dst0Desc
{},
Index
{}));
using
Dst0CoordStep
=
decltype
(
make_tensor_coordinate_step
(
Dst0Desc
{},
Index
{}));
using
Dst1CoordStep
=
decltype
(
make_tensor_coordinate_step
(
Dst1Desc
{},
Index
{}));
__device__
constexpr
ThreadwiseTensorSliceTransfer_v1r5
(
__device__
constexpr
ThreadwiseTensorSliceTransfer_v1r5
(
const
DstDesc
&
dst_desc
,
const
DstDesc
&
dst_desc
,
const
Dst0Desc
&
dst0_desc
,
const
Dst0Desc
&
dst0_desc
,
const
Dst1Desc
&
dst1_desc
,
const
Index
&
dst_slice_origin_idx
,
const
Index
&
dst_slice_origin_idx
,
const
DstElementwiseOperation
&
dst_element_op
)
const
DstElementwiseOperation
&
dst_element_op
)
:
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
)),
:
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
)),
dst0_coord_
(
make_tensor_coordinate
(
dst0_desc
,
dst_slice_origin_idx
)),
dst0_coord_
(
make_tensor_coordinate
(
dst0_desc
,
dst_slice_origin_idx
)),
dst1_coord_
(
make_tensor_coordinate
(
dst1_desc
,
dst_slice_origin_idx
)),
dst_element_op_
{
dst_element_op
}
dst_element_op_
{
dst_element_op
}
{
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
...
@@ -79,10 +74,8 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -79,10 +74,8 @@ struct ThreadwiseTensorSliceTransfer_v1r5
typename
SrcBuffer
,
typename
SrcBuffer
,
typename
DstBuffer
,
typename
DstBuffer
,
typename
Dst0Buffer
,
typename
Dst0Buffer
,
typename
Dst1Buffer
,
typename
DstStepHacks
,
typename
DstStepHacks
,
typename
Dst0StepHacks
,
typename
Dst0StepHacks
>
typename
Dst1StepHacks
>
__device__
void
Run
(
const
SrcDesc
&
,
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcSliceOriginIdx
&
,
const
SrcSliceOriginIdx
&
,
const
SrcBuffer
&
src_buf
,
const
SrcBuffer
&
src_buf
,
...
@@ -91,10 +84,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -91,10 +84,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
const
DstStepHacks
&
dst_step_hacks
,
const
DstStepHacks
&
dst_step_hacks
,
const
Dst0Desc
&
dst0_desc
,
const
Dst0Desc
&
dst0_desc
,
const
Dst0Buffer
&
dst0_buf
,
const
Dst0Buffer
&
dst0_buf
,
const
Dst0StepHacks
&
dst0_step_hacks
,
const
Dst0StepHacks
&
dst0_step_hacks
)
const
Dst1Desc
&
dst1_desc
,
const
Dst1Buffer
&
dst1_buf
,
const
Dst1StepHacks
&
dst1_step_hacks
)
{
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc need to known at compile-time"
);
"wrong! SrcDesc need to known at compile-time"
);
...
@@ -156,22 +146,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -156,22 +146,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
},
},
Number
<
nDim
>
{});
Number
<
nDim
>
{});
// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const
auto
dst1_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
dst_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
dst1_desc
,
forward_step_idx
,
dst1_step_hacks
[
I0
][
i
]);
},
Number
<
nDim
>
{});
// make backward steps: dst
// make backward steps: dst
const
auto
dst_backward_steps
=
generate_tuple
(
const
auto
dst_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -202,22 +176,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -202,22 +176,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
},
},
Number
<
nDim
>
{});
Number
<
nDim
>
{});
// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
const
auto
dst1_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
dst_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
dst1_desc
,
backward_step_idx
,
dst1_step_hacks
[
I1
][
i
]);
},
Number
<
nDim
>
{});
// loop over tensor and copy
// loop over tensor and copy
static_ford
<
decltype
(
ordered_access_lengths
)
>
{}([
&
](
auto
ordered_access_idx
)
{
static_ford
<
decltype
(
ordered_access_lengths
)
>
{}([
&
](
auto
ordered_access_idx
)
{
// judge move forward or move backward
// judge move forward or move backward
...
@@ -258,7 +216,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -258,7 +216,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
using
dst_vector_t
=
using
dst_vector_t
=
typename
vector_type_maker
<
DstData
,
DstScalarPerVector
>::
type
::
type
;
typename
vector_type_maker
<
DstData
,
DstScalarPerVector
>::
type
::
type
;
// load dst0 and
dst1 and
apply elementwise operation
// load dst0 and apply elementwise operation
{
{
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
// TODO: fix this
// TODO: fix this
...
@@ -270,29 +228,22 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -270,29 +228,22 @@ struct ThreadwiseTensorSliceTransfer_v1r5
const
SrcData
src_v
=
src_buf
[
Number
<
src_offset
>
{}];
const
SrcData
src_v
=
src_buf
[
Number
<
src_offset
>
{}];
// load dst0
and dst1
// load dst0
const
bool
is_dst0_valid
=
const
bool
is_dst0_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst0_desc
,
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst0_desc
,
dst0_coord_
);
dst0_coord_
);
const
bool
is_dst1_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst1_desc
,
dst1_coord_
);
const
DstData
dst0_v
=
const
DstData
dst0_v
=
dst0_buf
.
template
Get
<
DstData
>(
dst0_coord_
.
GetOffset
(),
is_dst0_valid
);
dst0_buf
.
template
Get
<
DstData
>(
dst0_coord_
.
GetOffset
(),
is_dst0_valid
);
const
DstData
dst1_v
=
dst1_buf
.
template
Get
<
DstData
>(
dst1_coord_
.
GetOffset
(),
is_dst1_valid
);
#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE
#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE
// apply element-wise operation in SrcData type
// apply element-wise operation in SrcData type
const
SrcData
dst_v
=
dst_element_op_
(
const
SrcData
dst_v
=
dst_element_op_
(
src_v
,
type_convert
<
SrcData
>
(
dst0_v
));
src_v
,
type_convert
<
SrcData
>
(
dst0_v
),
type_convert
<
SrcData
>
(
dst1_v
));
// apply type convert
// apply type convert
dst_vector
.
template
AsType
<
DstData
>()(
Number
<
0
>
{})
=
type_convert
<
DstData
>
(
dst_v
);
dst_vector
.
template
AsType
<
DstData
>()(
Number
<
0
>
{})
=
type_convert
<
DstData
>
(
dst_v
);
#else
#else
// apply element-wise operation in DstData type
// apply element-wise operation in DstData type
const
DstData
dst_v
=
dst_element_op_
(
src_v
,
dst0_v
,
dst1_v
);
const
DstData
dst_v
=
dst_element_op_
(
src_v
,
dst0_v
);
dst_vector
.
template
AsType
<
DstData
>()(
Number
<
0
>
{})
=
dst_v
;
dst_vector
.
template
AsType
<
DstData
>()(
Number
<
0
>
{})
=
dst_v
;
#endif
#endif
...
@@ -361,10 +312,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -361,10 +312,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
// dst0
// dst0
move_tensor_coordinate
(
move_tensor_coordinate
(
dst0_desc
,
dst0_coord_
,
dst0_forward_steps
[
dim_access_order
[
i
]]);
dst0_desc
,
dst0_coord_
,
dst0_forward_steps
[
dim_access_order
[
i
]]);
// dst1
move_tensor_coordinate
(
dst1_desc
,
dst1_coord_
,
dst1_forward_steps
[
dim_access_order
[
i
]]);
}
}
else
else
{
{
...
@@ -374,10 +321,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -374,10 +321,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
// dst0
// dst0
move_tensor_coordinate
(
move_tensor_coordinate
(
dst0_desc
,
dst0_coord_
,
dst0_backward_steps
[
dim_access_order
[
i
]]);
dst0_desc
,
dst0_coord_
,
dst0_backward_steps
[
dim_access_order
[
i
]]);
// dst1
move_tensor_coordinate
(
dst1_desc
,
dst1_coord_
,
dst1_backward_steps
[
dim_access_order
[
i
]]);
}
}
}
}
});
});
...
@@ -397,7 +340,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -397,7 +340,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
typename
SrcBuffer
,
typename
SrcBuffer
,
typename
DstBuffer
,
typename
DstBuffer
,
typename
Dst0Buffer
,
typename
Dst0Buffer
,
typename
Dst1Buffer
,
typename
DstStepHacks
>
typename
DstStepHacks
>
__device__
void
Run
(
const
SrcDesc
&
,
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcSliceOriginIdx
&
,
const
SrcSliceOriginIdx
&
,
...
@@ -406,9 +348,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -406,9 +348,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
DstBuffer
&
dst_buf
,
DstBuffer
&
dst_buf
,
const
DstStepHacks
&
dst_step_hacks
,
const
DstStepHacks
&
dst_step_hacks
,
const
Dst0Desc
&
dst0_desc
,
const
Dst0Desc
&
dst0_desc
,
const
Dst0Buffer
&
dst0_buf
,
const
Dst0Buffer
&
dst0_buf
)
const
Dst1Desc
&
dst1_desc
,
const
Dst1Buffer
&
dst1_buf
)
{
{
auto
f_step_hacks
=
[
&
](
auto
desc
)
{
auto
f_step_hacks
=
[
&
](
auto
desc
)
{
constexpr
index_t
ntransform
=
decltype
(
desc
)
::
GetNumOfTransform
();
constexpr
index_t
ntransform
=
decltype
(
desc
)
::
GetNumOfTransform
();
...
@@ -430,10 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -430,10 +370,7 @@ struct ThreadwiseTensorSliceTransfer_v1r5
dst_step_hacks
,
dst_step_hacks
,
dst0_desc
,
dst0_desc
,
dst0_buf
,
dst0_buf
,
f_step_hacks
(
dst0_desc
),
f_step_hacks
(
dst0_desc
));
dst1_desc
,
dst1_buf
,
f_step_hacks
(
dst1_desc
));
}
}
__device__
static
constexpr
auto
GetDstCoordinateResetStep
()
__device__
static
constexpr
auto
GetDstCoordinateResetStep
()
...
@@ -514,7 +451,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
...
@@ -514,7 +451,6 @@ struct ThreadwiseTensorSliceTransfer_v1r5
private:
private:
DstCoord
dst_coord_
;
DstCoord
dst_coord_
;
Dst0Coord
dst0_coord_
;
Dst0Coord
dst0_coord_
;
Dst1Coord
dst1_coord_
;
const
DstElementwiseOperation
dst_element_op_
;
const
DstElementwiseOperation
dst_element_op_
;
};
// namespace ck
};
// namespace ck
...
...
composable_kernel/include/utility/config.hpp
View file @
a84f254a
...
@@ -145,7 +145,6 @@
...
@@ -145,7 +145,6 @@
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1
#endif
#endif
namespace
ck
{
namespace
ck
{
enum
InMemoryDataOperationEnum_t
enum
InMemoryDataOperationEnum_t
...
...
device_operation/device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp
View file @
a84f254a
...
@@ -15,7 +15,7 @@ template <ck::index_t... Is>
...
@@ -15,7 +15,7 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AddRelu
Add
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
Add
;
using
AddRelu
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
using
device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
using
device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
...
@@ -23,24 +23,24 @@ using device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple<
...
@@ -23,24 +23,24 @@ using device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple<
//################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//################################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
Add
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
AddRelu
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances
(
void
add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances
(
std
::
vector
<
DeviceConvFwdBiasActivationPtr
<
PassThrough
,
PassThrough
,
AddRelu
Add
>>&
std
::
vector
<
DeviceConvFwdBiasActivationPtr
<
PassThrough
,
PassThrough
,
AddRelu
>>&
instance_container
)
instance_container
)
{
{
using
Instances
=
device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances
;
using
Instances
=
device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instances
;
...
...
device_operation/include/device_conv2d_fwd_xdl_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
a84f254a
...
@@ -17,7 +17,7 @@ namespace tensor_operation {
...
@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
// out[N, Ho, Wo, K] =
// out[N, Ho, Wo, K] =
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
+ residual[N, Ho, Wo, K]
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
OutDataType
,
...
@@ -209,14 +209,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -209,14 +209,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
auto
bias_grid_desc_gemmm_gemmn
=
const
auto
bias_grid_desc_gemmm_gemmn
=
make_naive_tensor_descriptor
(
make_tuple
(
GemmM
,
GemmN
),
make_tuple
(
I0
,
I1
));
make_naive_tensor_descriptor
(
make_tuple
(
GemmM
,
GemmN
),
make_tuple
(
I0
,
I1
));
// C1: residual tensor: assume same layout as output tensor
const
auto
resi_grid_desc_gemmm_gemmn
=
out_gemmm_gemmn_grid_desc
;
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemmm_gemmn_grid_desc
,
bias_grid_desc_gemmm_gemmn
,
bias_grid_desc_gemmm_gemmn
);
resi_grid_desc_gemmm_gemmn
);
}
}
using
ABCGridDescs
=
decltype
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
using
ABCGridDescs
=
decltype
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
...
@@ -226,7 +222,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -226,7 +222,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
C0GridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I3
])
>
;
using
C0GridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I3
])
>
;
using
C1GridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I4
])
>
;
// TODO remove these hacks
// TODO remove these hacks
static
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
static
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
...
@@ -279,7 +274,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -279,7 +274,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
C0GridDesc_M_N
,
C0GridDesc_M_N
,
C1GridDesc_M_N
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
@@ -325,9 +319,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -325,9 +319,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
using
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C1GridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
// Argument
// Argument
...
@@ -337,7 +328,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -337,7 +328,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
WeiDataType
*
p_wei_grid
,
const
WeiDataType
*
p_wei_grid
,
OutDataType
*
p_out_grid
,
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_bias_grid
,
const
OutDataType
*
p_bias_grid
,
const
OutDataType
*
p_resi_grid
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -357,15 +347,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -357,15 +347,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
p_b_grid_
{
p_wei_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_out_grid
},
p_c_grid_
{
p_out_grid
},
p_c0_grid_
{
p_bias_grid
},
p_c0_grid_
{
p_bias_grid
},
p_c1_grid_
{
p_resi_grid
},
a_grid_desc_k0_m_k1_
{},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
c0_grid_desc_m_n_
{},
c0_grid_desc_m_n_
{},
c1_grid_desc_m_n_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
...
@@ -389,7 +376,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -389,7 +376,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
c_grid_desc_m_n_
=
descs
[
I2
];
c0_grid_desc_m_n_
=
descs
[
I3
];
c0_grid_desc_m_n_
=
descs
[
I3
];
c1_grid_desc_m_n_
=
descs
[
I4
];
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
...
@@ -400,9 +386,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -400,9 +386,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c0_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c0_grid_desc_m_n_
);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c1_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
}
}
...
@@ -412,15 +395,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -412,15 +395,12 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
const
CDataType
*
p_c0_grid_
;
const
CDataType
*
p_c0_grid_
;
const
CDataType
*
p_c1_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -450,9 +430,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -450,9 +430,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std
::
cout
<<
"arg.c0_grid_desc_m_n_{ "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.c0_grid_desc_m_n_{ "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
", "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c1_grid_desc_m_n_{ "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
...
@@ -483,7 +460,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -483,7 +460,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
@@ -499,12 +475,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -499,12 +475,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
@@ -520,7 +494,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -520,7 +494,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceOp
::
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
@@ -536,12 +509,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -536,12 +509,10 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
@@ -581,7 +552,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -581,7 +552,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
WeiDataType
*
p_wei_grid
,
const
WeiDataType
*
p_wei_grid
,
OutDataType
*
p_out_grid
,
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_bias_grid
,
const
OutDataType
*
p_bias_grid
,
const
OutDataType
*
p_resi_grid
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -600,7 +570,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -600,7 +570,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
p_bias_grid
,
p_bias_grid
,
p_resi_grid
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -625,7 +594,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -625,7 +594,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
void
*
p_wei_grid
,
const
void
*
p_wei_grid
,
void
*
p_out_grid
,
void
*
p_out_grid
,
const
void
*
p_bias_grid
,
const
void
*
p_bias_grid
,
const
void
*
p_resi_grid
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -644,7 +612,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -644,7 +612,6 @@ struct DeviceConv2dFwdXdl_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
OutDataType
*>
(
p_out_grid
),
static_cast
<
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_bias_grid
),
static_cast
<
const
OutDataType
*>
(
p_bias_grid
),
static_cast
<
const
OutDataType
*>
(
p_resi_grid
),
N
,
N
,
K
,
K
,
C
,
C
,
...
...
device_operation/include/device_conv_fwd_bias_activation.hpp
View file @
a84f254a
...
@@ -18,7 +18,6 @@ struct DeviceConvFwdBiasActivation : public BaseOperator
...
@@ -18,7 +18,6 @@ struct DeviceConvFwdBiasActivation : public BaseOperator
const
void
*
p_wei
,
const
void
*
p_wei
,
void
*
p_out
,
void
*
p_out
,
const
void
*
p_bias
,
const
void
*
p_bias
,
const
void
*
p_resi
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
...
device_operation/include/element_wise_operation.hpp
View file @
a84f254a
...
@@ -14,6 +14,34 @@ struct PassThrough
...
@@ -14,6 +14,34 @@ struct PassThrough
}
}
};
};
struct
AddRelu
{
template
<
typename
T1
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
)
const
{
float
b
=
v0
+
v1
;
float
c
=
b
>
0
?
b
:
0
;
return
c
;
}
template
<
typename
T1
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
)
const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
return b;
#else
float
b
=
v1
+
v0
;
float
c
=
b
>
0
?
b
:
0
;
return
c
;
#endif
}
};
struct
AddReluAdd
struct
AddReluAdd
{
{
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
...
@@ -44,6 +72,79 @@ struct AddReluAdd
...
@@ -44,6 +72,79 @@ struct AddReluAdd
}
}
};
};
struct
AddLeakyReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
float
a
=
v0
+
v1
;
float
b
=
0.1
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif
0
// this spill register
float
a
=
v0
+
v1
;
float
b
=
float
(
0.1
)
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#elif 0
// this use lots of registers (but no spill)
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
max
(
b
,
float
(
0
));
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this spill registers, 89 Tflops
float
a
=
v0
+
v1
;
float
alpha
=
0.1
;
float
b
;
asm
volatile
(
"
\n
\
v_mul_f32_e32 %0, %1, %2
\n
\
"
:
"=v"
(
b
)
:
"s"
(
alpha
),
"v"
(
a
));
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#endif
}
};
}
// namespace element_wise
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
...
...
example/1_gemm_xdl/gemm_xdl.cpp
View file @
a84f254a
...
@@ -13,24 +13,7 @@
...
@@ -13,24 +13,7 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
struct
PassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
;
}
};
struct
Relu
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
>
0
?
v
:
0
;
}
};
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -44,9 +27,9 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -44,9 +27,9 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
A
Op
=
PassThrough
;
using
A
ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
B
Op
=
PassThrough
;
using
B
ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
C
Op
=
Relu
;
using
C
ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for NT problem
// Compilation parameters for NT problem
// clang-format off
// clang-format off
...
@@ -55,7 +38,7 @@ using DeviceGemmInstance =
...
@@ -55,7 +38,7 @@ using DeviceGemmInstance =
//#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AOp
,
BOp
,
C
Op
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
;
ck
::
tensor_operation
::
device
::
DeviceGemmXdl
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElement
Op
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
;
// clang-format on
// clang-format on
template
<
typename
AType
,
template
<
typename
AType
,
...
@@ -189,9 +172,9 @@ int main(int argc, char* argv[])
...
@@ -189,9 +172,9 @@ int main(int argc, char* argv[])
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
AOp
{},
A
Element
Op
{},
BOp
{},
B
Element
Op
{},
COp
{});
C
Element
Op
{});
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -217,7 +200,7 @@ int main(int argc, char* argv[])
...
@@ -217,7 +200,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
host_verify
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
A
Op
{},
BOp
{},
C
Op
{});
host_verify
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
A
ElementOp
{},
BElementOp
{},
CElement
Op
{});
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
}
}
...
...
example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp
View file @
a84f254a
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
#include "host_gemm.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "example/
2
_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp"
#include "example/
3
_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp"
// C[m, n] = Relu(A[m, k] * B[k, n] + C0[m]) + C1[m, n]
// C[m, n] = Relu(A[m, k] * B[k, n] + C0[m]) + C1[m, n]
// assume C0 is contiguous in memory
// assume C0 is contiguous in memory
...
...
example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp
View file @
a84f254a
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define DEVICE_GEMM_XDL_TWO_EXTRA_SOURCE_REDUCE_HPP
#define DEVICE_GEMM_XDL_TWO_EXTRA_SOURCE_REDUCE_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "device_gemm.hpp"
...
@@ -560,6 +561,23 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
...
@@ -560,6 +561,23 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmXdl_two_extra_source_reduce"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
};
}
// namespace device
}
// namespace device
...
...
example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp
View file @
a84f254a
...
@@ -12,25 +12,7 @@
...
@@ -12,25 +12,7 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
struct
PassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
;
}
};
struct
Relu
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
T
tmp
=
0.1
*
v
;
return
tmp
>
0
?
tmp
:
0
;
}
};
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
...
@@ -44,9 +26,9 @@ using InLayout = ck::tensor_layout::convolution::NHWC;
...
@@ -44,9 +26,9 @@ using InLayout = ck::tensor_layout::convolution::NHWC;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYXC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHWK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHWK
;
using
InElementOp
=
PassThrough
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
Relu
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceConvFwdInstance
=
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
...
example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp
View file @
a84f254a
...
@@ -28,7 +28,7 @@ using OutLayout = ck::tensor_layout::convolution::NHWK;
...
@@ -28,7 +28,7 @@ using OutLayout = ck::tensor_layout::convolution::NHWK;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
Add
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
// clang-format off
// clang-format off
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
...
@@ -50,7 +50,6 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
...
@@ -50,7 +50,6 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
const
Tensor
<
TWei
>&
wei_k_c_y_x
,
const
Tensor
<
TWei
>&
wei_k_c_y_x
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
const
Tensor
<
TOut
>&
bias_k
,
const
Tensor
<
TOut
>&
bias_k
,
const
Tensor
<
TOut
>&
resi_n_k_ho_wo
,
const
std
::
vector
<
ck
::
index_t
>&
conv_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
conv_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
in_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
in_left_pads
,
...
@@ -79,7 +78,7 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
...
@@ -79,7 +78,7 @@ void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
}
}
}
}
out_n_k_ho_wo
(
n
,
k
,
ho
,
wo
)
=
out_element_op
(
v
,
bias_k
(
k
)
,
resi_n_k_ho_wo
(
n
,
k
,
ho
,
wo
)
);
out_n_k_ho_wo
(
n
,
k
,
ho
,
wo
)
=
out_element_op
(
v
,
bias_k
(
k
));
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
f_nchw
,
...
@@ -198,14 +197,10 @@ int main(int argc, char* argv[])
...
@@ -198,14 +197,10 @@ int main(int argc, char* argv[])
Tensor
<
OutDataType
>
bias_k
(
Tensor
<
OutDataType
>
bias_k
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
K
)})));
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
K
)})));
// residual: assume same layout as output tensor
Tensor
<
OutDataType
>
resi_n_k_ho_wo
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"bias_k: "
<<
bias_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"bias_k: "
<<
bias_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"resi_n_k_ho_wo: "
<<
resi_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -214,13 +209,11 @@ int main(int argc, char* argv[])
...
@@ -214,13 +209,11 @@ int main(int argc, char* argv[])
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
resi_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
0.0
,
1.0
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
0.0
,
1.0
});
resi_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
0.0
,
1.0
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
...
@@ -228,12 +221,10 @@ int main(int argc, char* argv[])
...
@@ -228,12 +221,10 @@ int main(int argc, char* argv[])
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
());
DeviceMem
resi_device_buf
(
sizeof
(
OutDataType
)
*
resi_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias_k
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias_k
.
mData
.
data
());
resi_device_buf
.
ToDevice
(
resi_n_k_ho_wo
.
mData
.
data
());
auto
conv
=
DeviceConvFwdInstance
{};
auto
conv
=
DeviceConvFwdInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
invoker
=
conv
.
MakeInvoker
();
...
@@ -242,7 +233,6 @@ int main(int argc, char* argv[])
...
@@ -242,7 +233,6 @@ int main(int argc, char* argv[])
static_cast
<
const
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
OutDataType
*>
(
bias_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
OutDataType
*>
(
bias_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
OutDataType
*>
(
resi_device_buf
.
GetDeviceBuffer
()),
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -270,8 +260,7 @@ int main(int argc, char* argv[])
...
@@ -270,8 +260,7 @@ int main(int argc, char* argv[])
std
::
size_t
num_btype
=
sizeof
(
InDataType
)
*
(
N
*
C
*
Hi
*
Wi
)
+
std
::
size_t
num_btype
=
sizeof
(
InDataType
)
*
(
N
*
C
*
Hi
*
Wi
)
+
sizeof
(
WeiDataType
)
*
(
K
*
C
*
Y
*
X
)
+
sizeof
(
WeiDataType
)
*
(
K
*
C
*
Y
*
X
)
+
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
)
+
sizeof
(
OutDataType
)
*
(
K
)
+
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
)
+
sizeof
(
OutDataType
)
*
(
K
);
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -286,7 +275,6 @@ int main(int argc, char* argv[])
...
@@ -286,7 +275,6 @@ int main(int argc, char* argv[])
wei_k_c_y_x
,
wei_k_c_y_x
,
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_host_result
,
bias_k
,
bias_k
,
resi_n_k_ho_wo
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp
View file @
a84f254a
...
@@ -14,80 +14,6 @@
...
@@ -14,80 +14,6 @@
#include "device_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp"
#include "device_conv2d_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
struct
AddLeakyReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
float
a
=
v0
+
v1
;
float
b
=
0.1
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif
0
// this spill register
float
a
=
v0
+
v1
;
float
b
=
float
(
0.1
)
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#elif 0
// this use lots of registers (but no spill)
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
max
(
b
,
float
(
0
));
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this spill registers, 89 Tflops
float
a
=
v0
+
v1
;
float
alpha
=
0.1
;
float
b
;
asm
volatile
(
"
\n
\
v_mul_f32_e32 %0, %1, %2
\n
\
"
:
"=v"
(
b
)
:
"s"
(
alpha
),
"v"
(
a
));
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#endif
}
};
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
...
...
profiler/CMakeLists.txt
View file @
a84f254a
...
@@ -44,6 +44,17 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC)
...
@@ -44,6 +44,17 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC)
set_target_properties
(
device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
set_target_properties
(
device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib
)
# device_conv2d_fwd_bias_relu_instance
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv2d_fwd_bias_relu_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
}
)
target_include_directories
(
device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_compile_features
(
device_conv2d_fwd_bias_relu_instance PUBLIC
)
set_target_properties
(
device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib
)
# device_conv2d_fwd_bias_relu_add_instance
# device_conv2d_fwd_bias_relu_add_instance
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv2d_fwd_xdl_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
...
@@ -56,10 +67,11 @@ set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITI
...
@@ -56,10 +67,11 @@ set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITI
install
(
TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib
)
# ck_profiler
# ck_profiler
set
(
PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv_fwd.cpp profile_conv_fwd_bias_relu_add.cpp
)
set
(
PROFILER_SOURCE profiler.cpp profile_gemm.cpp profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
0 → 100644
View file @
a84f254a
#pragma once
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "device_conv_fwd_bias_activation.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_bias_activation_instance
{
using
DeviceConvFwdBiasReluPtr
=
DeviceConvFwdBiasActivationPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddRelu
>
;
void
add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances
(
std
::
vector
<
DeviceConvFwdBiasReluPtr
>&
);
}
// namespace device_conv2d_fwd_bias_activation_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
profiler
{
template
<
typename
TIn
,
typename
TWei
,
typename
TOut
,
typename
InElementOp
,
typename
WeiElementOp
,
typename
OutElementOp
>
void
host_reference_calculation
(
const
Tensor
<
TIn
>&
in_n_c_hi_wi
,
const
Tensor
<
TWei
>&
wei_k_c_y_x
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
const
Tensor
<
TOut
>&
bias_k
,
const
std
::
vector
<
ck
::
index_t
>&
conv_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
in_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
/* in_right_pads */
,
const
InElementOp
&
in_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
OutElementOp
&
out_element_op
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
double
v
=
0
;
for
(
int
c
=
0
;
c
<
wei_k_c_y_x
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
for
(
int
y
=
0
;
y
<
wei_k_c_y_x
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
int
hi
=
ho
*
conv_strides
[
0
]
+
y
*
conv_dilations
[
0
]
-
in_left_pads
[
0
];
for
(
int
x
=
0
;
x
<
wei_k_c_y_x
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
int
wi
=
wo
*
conv_strides
[
1
]
+
x
*
conv_dilations
[
1
]
-
in_left_pads
[
1
];
if
(
hi
>=
0
&&
hi
<
in_n_c_hi_wi
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_n_c_hi_wi
.
mDesc
.
GetLengths
()[
3
])
{
v
+=
in_element_op
(
static_cast
<
const
double
>
(
in_n_c_hi_wi
(
n
,
c
,
hi
,
wi
)))
*
wei_element_op
(
static_cast
<
const
double
>
(
wei_k_c_y_x
(
k
,
c
,
y
,
x
)));
}
}
}
}
out_n_k_ho_wo
(
n
,
k
,
ho
,
wo
)
=
out_element_op
(
v
,
bias_k
(
k
));
};
make_ParallelTensorFunctor
(
f_nchw
,
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
0
],
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
1
],
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
2
],
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
}
template
<
int
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
void
profile_conv_fwd_bias_relu_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
int
nrepeat
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
const
ck
::
index_t
Y
=
filter_spatial_lengths
[
0
];
const
ck
::
index_t
X
=
filter_spatial_lengths
[
1
];
const
ck
::
index_t
Hi
=
input_spatial_lengths
[
0
];
const
ck
::
index_t
Wi
=
input_spatial_lengths
[
1
];
const
ck
::
index_t
Ho
=
output_spatial_lengths
[
0
];
const
ck
::
index_t
Wo
=
output_spatial_lengths
[
1
];
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
N_
,
std
::
size_t
C_
,
std
::
size_t
H
,
std
::
size_t
W
,
auto
layout
)
{
if
constexpr
(
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCHW
>::
value
||
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
KCYX
>::
value
||
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NKHW
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
N_
,
C_
,
H
,
W
}),
std
::
vector
<
std
::
size_t
>
({
C_
*
H
*
W
,
H
*
W
,
W
,
1
}));
}
else
if
constexpr
(
is_same
<
decltype
(
layout
),
tensor_layout
::
convolution
::
NHWC
>::
value
||
is_same
<
decltype
(
layout
),
tensor_layout
::
convolution
::
KYXC
>::
value
||
is_same
<
decltype
(
layout
),
tensor_layout
::
convolution
::
NHWK
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
N_
,
C_
,
H
,
W
}),
std
::
vector
<
std
::
size_t
>
({
C_
*
H
*
W
,
1
,
W
*
C_
,
C_
}));
}
};
Tensor
<
InDataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
,
InLayout
{}));
Tensor
<
WeiDataType
>
wei_k_c_y_x
(
f_host_tensor_descriptor
(
K
,
C
,
Y
,
X
,
WeiLayout
{}));
Tensor
<
OutDataType
>
out_n_k_ho_wo_host_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
// bias: assume contiguous 1d vector
Tensor
<
OutDataType
>
bias_k
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
K
)})));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"bias_k: "
<<
bias_k
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
break
;
default:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
0.0
,
1.0
});
}
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
if
(
do_verification
)
{
host_reference_calculation
(
in_n_c_hi_wi
,
wei_k_c_y_x
,
out_n_k_ho_wo_host_result
,
bias_k
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias_k
.
mData
.
data
());
using
DeviceConvFwdBiasReluPtr
=
ck
::
tensor_operation
::
device
::
DeviceConvFwdBiasActivationPtr
<
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
// add device operator instances
std
::
vector
<
DeviceConvFwdBiasReluPtr
>
op_ptrs
;
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_bias_activation_instance
::
add_device_conv2d_fwd_bias_relu_xdl_nhwc_kyxc_nhwk_fp16_instances
(
op_ptrs
);
}
if
(
op_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device Conv instance found"
);
}
std
::
string
best_conv_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device Conv instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
const
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
OutDataType
*>
(
bias_device_buf
.
GetDeviceBuffer
()),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
conv_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
;
std
::
size_t
num_btype
=
sizeof
(
InDataType
)
*
(
N
*
C
*
Hi
*
Wi
)
+
sizeof
(
WeiDataType
)
*
(
K
*
C
*
Y
*
X
)
+
sizeof
(
OutDataType
)
*
(
N
*
K
*
Ho
*
Wo
)
+
sizeof
(
OutDataType
)
*
(
K
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
conv_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_conv_name
=
conv_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
out_device_buf
.
FromDevice
(
out_n_k_ho_wo_device_result
.
mData
.
data
());
check_error
(
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_device_result
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"in : "
,
in_n_c_hi_wi
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"wei: "
,
wei_k_c_y_x
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_host : "
,
out_n_k_ho_wo_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"out_device: "
,
out_n_k_ho_wo_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_conv_name
<<
std
::
endl
;
}
}
// namespace profiler
}
// namespace ck
profiler/profile_conv_fwd_bias_relu.cpp
0 → 100644
View file @
a84f254a
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp"
enum
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
};
enum
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
int
profile_conv_fwd_bias_relu
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
25
)
{
printf
(
"arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16)
\n
"
);
printf
(
"arg3: input tensor layout (0: NCHW; 1: NHWC)
\n
"
);
printf
(
"arg4: weight tensor layout (0: KCYX; 1: KYXC)
\n
"
);
printf
(
"arg5: output tensor layout (0: NKHW; 1: NHWK)
\n
"
);
printf
(
"arg6: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg7: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg8: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg9: run kernel # of times (>1)
\n
"
);
printf
(
"arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
1
);
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
6
]);
const
int
init_method
=
std
::
stoi
(
argv
[
7
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
8
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
18
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
19
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
20
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
21
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
22
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
23
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
24
]);
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
YEff
)
/
conv_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
XEff
)
/
conv_stride_w
+
1
;
if
(
data_type
==
ConvDataType
::
F16_F16_F16
&&
in_layout
==
ConvInputLayout
::
NHWC
&&
wei_layout
==
ConvWeightLayout
::
KYXC
&&
out_layout
==
ConvOutputLayout
::
NHWK
)
{
ck
::
profiler
::
profile_conv_fwd_bias_relu_impl
<
2
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
N
,
K
,
C
,
std
::
vector
<
ck
::
index_t
>
{
Hi
,
Wi
},
std
::
vector
<
ck
::
index_t
>
{
Y
,
X
},
std
::
vector
<
ck
::
index_t
>
{
Ho
,
Wo
},
std
::
vector
<
ck
::
index_t
>
{
conv_stride_h
,
conv_stride_w
},
std
::
vector
<
ck
::
index_t
>
{
conv_dilation_h
,
conv_dilation_w
},
std
::
vector
<
ck
::
index_t
>
{
in_left_pad_h
,
in_left_pad_w
},
std
::
vector
<
ck
::
index_t
>
{
in_right_pad_h
,
in_right_pad_w
});
}
else
{
throw
std
::
runtime_error
(
"wrong! data_type & layout for this operator is not implemented"
);
}
return
1
;
}
profiler/profiler.cpp
View file @
a84f254a
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
int
profile_gemm
(
int
,
char
*
[]);
int
profile_gemm
(
int
,
char
*
[]);
int
profile_conv_fwd
(
int
,
char
*
[]);
int
profile_conv_fwd
(
int
,
char
*
[]);
int
profile_conv_fwd_bias_relu
(
int
,
char
*
[]);
int
profile_conv_fwd_bias_relu_add
(
int
,
char
*
[]);
int
profile_conv_fwd_bias_relu_add
(
int
,
char
*
[]);
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
@@ -19,6 +20,10 @@ int main(int argc, char* argv[])
...
@@ -19,6 +20,10 @@ int main(int argc, char* argv[])
{
{
return
profile_conv_fwd
(
argc
,
argv
);
return
profile_conv_fwd
(
argc
,
argv
);
}
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu"
)
==
0
)
{
return
profile_conv_fwd_bias_relu
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_add"
)
==
0
)
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_add"
)
==
0
)
{
{
return
profile_conv_fwd_bias_relu_add
(
argc
,
argv
);
return
profile_conv_fwd_bias_relu_add
(
argc
,
argv
);
...
@@ -28,6 +33,7 @@ int main(int argc, char* argv[])
...
@@ -28,6 +33,7 @@ int main(int argc, char* argv[])
printf
(
printf
(
"arg1: tensor operation (gemm: GEMM;
\n
"
"arg1: tensor operation (gemm: GEMM;
\n
"
" conv_fwd: ForwardConvolution;
\n
"
" conv_fwd: ForwardConvolution;
\n
"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU)
\n
"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add)
\n
"
);
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add)
\n
"
);
return
0
;
return
0
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment