Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d8c89b68
Commit
d8c89b68
authored
May 10, 2021
by
Chao Liu
Browse files
refactor driver for conv
parent
fd160c63
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1360 additions
and
2643 deletions
+1360
-2643
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+419
-1525
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+281
-1014
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
+396
-0
composable_kernel/include/tensor_description/tensor_adaptor.hpp
...able_kernel/include/tensor_description/tensor_adaptor.hpp
+21
-0
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+37
-8
driver/include/device.hpp
driver/include/device.hpp
+22
-4
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+94
-44
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+87
-45
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+3
-3
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
d8c89b68
This diff is collapsed.
Click to expand it.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
d8c89b68
This diff is collapsed.
Click to expand it.
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
0 → 100644
View file @
d8c89b68
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_description/tensor_adaptor.hpp
View file @
d8c89b68
...
@@ -184,6 +184,27 @@ struct TensorAdaptor
...
@@ -184,6 +184,27 @@ struct TensorAdaptor
return
get_container_subset
(
idx_hidden
,
BottomDimensionHiddenIds
{});
return
get_container_subset
(
idx_hidden
,
BottomDimensionHiddenIds
{});
}
}
__host__
__device__
void
Print
()
const
{
printf
(
"{"
);
printf
(
"TensorAdaptor, "
);
static_for
<
0
,
ntransform_
,
1
>
{}([
&
](
auto
i
)
{
printf
(
"transforms: "
);
transforms_
[
i
].
Print
();
printf
(
"LowerDimensionHiddenIds:"
);
LowerDimensionHiddenIdss
{}.
At
(
i
).
Print
();
printf
(
"UpperDimensionHiddenIds:"
);
UpperDimensionHiddenIdss
{}.
At
(
i
).
Print
();
});
printf
(
"BottomDimensionHiddenIds:"
);
BottomDimensionHiddenIds
::
Print
();
printf
(
"TopDimensionHiddenIds:"
);
TopDimensionHiddenIds
::
Print
();
printf
(
"}"
);
}
private:
private:
Transforms
transforms_
;
Transforms
transforms_
;
ElementSize
element_size_
;
ElementSize
element_size_
;
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
d8c89b68
...
@@ -12,7 +12,36 @@
...
@@ -12,7 +12,36 @@
namespace
ck
{
namespace
ck
{
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template
<
typename
GridwiseGemm
,
typename
AGlobalDesc
,
typename
FloatA
,
typename
BGlobalDesc
,
typename
FloatB
,
typename
CGlobalDesc
,
typename
FloatC
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__global__
void
kernel_dynamic_gemm_v1
(
const
AGlobalDesc
a_k_m_global_desc
,
const
FloatA
*
__restrict__
p_a_global
,
const
BGlobalDesc
b_k_n_global_desc
,
const
FloatB
*
__restrict__
p_b_global
,
const
CGlobalDesc
c_m0_m1_n0_n1_global_desc
,
FloatC
*
__restrict__
p_c_global
,
const
CBlockClusterDesc
c_block_cluster_desc
)
{
GridwiseGemm
{}.
Run
(
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
// non-modifiable parameter address space, so compiler can enable corresponding optimization
...
@@ -26,13 +55,13 @@ template <typename GridwiseGemm,
...
@@ -26,13 +55,13 @@ template <typename GridwiseGemm,
typename
CBlockClusterDesc
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
__global__
void
run_gridwise
_dynamic_gemm_v1
(
const
void
__CONSTANT__
*
p_a_k_m_global_desc
,
__global__
void
kernel
_dynamic_gemm_v1
(
const
void
__CONSTANT__
*
p_a_k_m_global_desc
,
const
FloatA
*
__restrict__
p_a_global
,
const
FloatA
*
__restrict__
p_a_global
,
const
void
__CONSTANT__
*
p_b_k_n_global_desc
,
const
void
__CONSTANT__
*
p_b_k_n_global_desc
,
const
FloatB
*
__restrict__
p_b_global
,
const
FloatB
*
__restrict__
p_b_global
,
const
void
__CONSTANT__
*
p_c_m0_m1_n0_n1_global_desc
,
const
void
__CONSTANT__
*
p_c_m0_m1_n0_n1_global_desc
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
const
void
__CONSTANT__
*
p_c_block_cluster_desc
)
const
void
__CONSTANT__
*
p_c_block_cluster_desc
)
{
{
// first cast void __CONSTANT__ void* to void*
// first cast void __CONSTANT__ void* to void*
// second cast void* to Desc*
// second cast void* to Desc*
...
...
driver/include/device.hpp
View file @
d8c89b68
...
@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
...
@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
template
<
typename
...
Args
,
typename
F
>
template
<
typename
...
Args
,
typename
F
>
float
launch_and_time_kernel
(
F
kernel
,
float
launch_and_time_kernel
(
F
kernel
,
int
nrepeat
,
dim3
grid_dim
,
dim3
grid_dim
,
dim3
block_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
std
::
size_t
lds_byte
,
...
@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
...
@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
{
{
KernelTimer
timer
;
KernelTimer
timer
;
timer
.
Start
();
printf
(
"%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
printf
(
"Warm up
\n
"
);
// warm up
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
timer
.
End
();
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
timer
.
Start
();
hipGetLastError
();
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
}
return
timer
.
GetElapsedTime
();
timer
.
End
();
return
timer
.
GetElapsedTime
()
/
nrepeat
;
}
}
#elif CK_DEVICE_BACKEND_NVIDIA
#elif CK_DEVICE_BACKEND_NVIDIA
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
d8c89b68
...
@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
{
{
using
namespace
ck
;
using
namespace
ck
;
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
std
::
cout
<<
__func__
<<
std
::
endl
;
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_c_hi_wi_device_buf
(
sizeof
(
TInWei
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_c_y_x_device_buf
(
sizeof
(
TInWei
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
...
@@ -459,50 +468,91 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
...
@@ -459,50 +468,91 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#endif
#endif
constexpr
auto
conv_driver
=
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
auto
descs
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitG
emm_v4r4_nchw_kcyx_nkhw_pad
transform_forward_convolution_into_g
emm_v4r4_nchw_kcyx_nkhw_pad
#elif 0
#elif 0
DriverDynamicConvolutionForwardImplicitG
emm_v4r4_nchw_kcyx_nkhw_no_pad
transform_forward_convolution_into_g
emm_v4r4_nchw_kcyx_nkhw_no_pad
#el
if 1
#el
se
DriverDynamicConvolutionForwardImplicitG
emm_v4r4_nchw_kcyx_nkhw_1x1
transform_forward_convolution_into_g
emm_v4r4_nchw_kcyx_nkhw_1x1
#endif
#endif
<
BlockSize
,
<
GemmMPerBlock
,
GemmNPerBlock
,
GemmM1
,
GemmN1
>
(
wei_k_c_y_x_desc
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
in_n_c_hi_wi_desc
,
TAcc
,
out_n_k_ho_wo_desc
,
TOut
,
conv_strides
,
GemmMPerBlock
,
conv_dilations
,
GemmNPerBlock
,
in_left_pads
,
GemmKPerBlock
,
in_right_pads
);
GemmMPerThread
,
GemmNPerThread
,
float
ave_time
=
launch_kernel_dynamic_gemm_v1
<
GemmKPerThread
,
BlockSize
,
GemmMLevel0Cluster
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
GemmNLevel0Cluster
,
TAcc
,
GemmMLevel1Cluster
,
TOut
,
GemmNLevel1Cluster
,
InMemoryDataOperation
::
Set
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
decltype
(
descs
[
I0
]),
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
decltype
(
descs
[
I1
]),
GemmABlockTransferSrcScalarPerVector_GemmK
,
decltype
(
descs
[
I2
]),
GemmABlockTransferDstScalarPerVector_GemmM
,
decltype
(
descs
[
I3
]),
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmMPerBlock
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmNPerBlock
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmKPerBlock
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmMPerThread
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
{};
GemmNPerThread
,
GemmKPerThread
,
conv_driver
.
Run
(
wei_k_c_y_x_desc
,
GemmMLevel0Cluster
,
in_n_c_hi_wi_desc
,
GemmNLevel0Cluster
,
out_n_k_ho_wo_desc
,
GemmMLevel1Cluster
,
conv_strides
,
GemmNLevel1Cluster
,
conv_dilations
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
in_left_pads
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
in_right_pads
,
Sequence
<
1
,
0
>
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
Sequence
<
1
,
0
>
,
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
0
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
GemmABlockTransferSrcScalarPerVector_GemmK
,
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
GemmABlockTransferDstScalarPerVector_GemmM
,
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()));
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
,
decltype
(
descs
[
I4
]),
decltype
(
descs
[
I5
]),
decltype
(
descs
[
I6
]),
decltype
(
descs
[
I7
]),
decltype
(
descs
[
I8
])
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
descs
[
I3
],
descs
[
I4
],
descs
[
I5
],
descs
[
I6
],
descs
[
I7
],
descs
[
I8
],
nrepeat
);
float
perf
=
(
float
)
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
// copy result back to host
out_n_k_ho_wo_device_buf
.
FromDevice
(
out_n_k_ho_wo
.
mData
.
data
());
out_n_k_ho_wo_device_buf
.
FromDevice
(
out_n_k_ho_wo
.
mData
.
data
());
}
}
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
d8c89b68
...
@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
{
{
using
namespace
ck
;
using
namespace
ck
;
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
std
::
cout
<<
__func__
<<
std
::
endl
;
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
constexpr
auto
N
=
OutDesc
::
GetLengths
()[
I0
];
constexpr
auto
N
=
OutDesc
::
GetLengths
()[
I0
];
constexpr
auto
K
=
OutDesc
::
GetLengths
()[
I1
];
constexpr
auto
K
=
OutDesc
::
GetLengths
()[
I1
];
...
@@ -372,51 +376,89 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
...
@@ -372,51 +376,89 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#endif
#endif
constexpr
auto
conv_driver
=
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
auto
descs
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0
#else
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif
#endif
<
BlockSize
,
<
GemmMPerBlock
,
GemmNPerBlock
,
GemmM1
,
GemmN1
>
(
wei_k_y_x_c0_desc
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
in_n_hi_wi_c0_desc
,
TAcc
,
out_n_ho_wo_k_desc
,
TOut
,
conv_strides
,
GemmMPerBlock
,
conv_dilations
,
GemmNPerBlock
,
in_left_pads
,
GemmKPerBlock
,
in_right_pads
);
GemmMPerThread
,
GemmNPerThread
,
float
ave_time
=
launch_kernel_dynamic_gemm_v1
<
GemmKPerThread
,
BlockSize
,
GemmMLevel0Cluster
,
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
,
GemmNLevel0Cluster
,
TAcc
,
GemmMLevel1Cluster
,
TOut
,
GemmNLevel1Cluster
,
InMemoryDataOperation
::
Set
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
decltype
(
descs
[
I0
]),
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
decltype
(
descs
[
I1
]),
GemmABlockTransferSrcScalarPerVector_GemmK
,
decltype
(
descs
[
I2
]),
GemmABlockTransferDstScalarPerVector_GemmM
,
decltype
(
descs
[
I3
]),
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmMPerBlock
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmNPerBlock
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmKPerBlock
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmMPerThread
,
GemmCThreadTransferDstScalarPerVector_GemmM1
>
{};
GemmNPerThread
,
GemmKPerThread
,
conv_driver
.
Run
(
wei_k_y_x_c0_desc
,
GemmMLevel0Cluster
,
in_n_hi_wi_c0_desc
,
GemmNLevel0Cluster
,
out_n_ho_wo_k_desc
,
GemmMLevel1Cluster
,
conv_strides
,
GemmNLevel1Cluster
,
conv_dilations
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
in_left_pads
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
in_right_pads
,
Sequence
<
1
,
0
>
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
Sequence
<
1
,
0
>
,
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
0
,
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
GemmABlockTransferSrcScalarPerVector_GemmK
,
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
GemmABlockTransferDstScalarPerVector_GemmM
,
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()));
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
1
,
GemmCThreadTransferDstScalarPerVector_GemmM1
,
decltype
(
descs
[
I4
]),
decltype
(
descs
[
I5
]),
decltype
(
descs
[
I6
]),
decltype
(
descs
[
I7
]),
decltype
(
descs
[
I8
])
>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
descs
[
I3
],
descs
[
I4
],
descs
[
I5
],
descs
[
I6
],
descs
[
I7
],
descs
[
I8
],
nrepeat
);
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
// copy result back to host
out_n_ho_wo_k_device_buf
.
FromDevice
(
out_n_ho_wo_k
.
mData
.
data
());
out_n_ho_wo_k_device_buf
.
FromDevice
(
out_n_ho_wo_k
.
mData
.
data
());
auto
f_nhwk2nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nhwk2nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
...
...
driver/src/conv_driver.cpp
View file @
d8c89b68
...
@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
...
@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
0
#elif
1
// 3x3, 71x71
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
C
=
192
;
...
@@ -225,7 +225,7 @@ int main(int argc, char* argv[])
...
@@ -225,7 +225,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
// 7x1, 17x17
// 7x1, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
...
@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif
0
#elif
1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
<
in_data_t
,
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
<
in_data_t
,
in_vector_size
,
in_vector_size
,
acc_data_t
,
acc_data_t
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment