Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c31867bd
Commit
c31867bd
authored
Feb 24, 2021
by
Chao Liu
Browse files
add non-padded
parent
5f0c56d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1040 additions
and
15 deletions
+1040
-15
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+654
-0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+113
-10
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
...convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+200
-2
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+73
-3
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
c31867bd
...
@@ -229,6 +229,660 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
...
@@ -229,6 +229,660 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
printf
(
"%s: BlockSize %d, GridSize %d
\n
"
,
__func__
,
BlockSize
,
GridSize
);
#if 1 // pass tensor descriptors by their reference
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif 1 // pass tensor descriptors by their pointers
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif 1 // pass tensor descriptor by void*
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#endif
}
};
template
<
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_y_x_c_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_hi_wi_c_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_ho_wo_k_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
auto
N
=
in_n_hi_wi_c_global_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_hi_wi_c_global_desc
.
GetLength
(
I3
);
const
auto
K
=
out_n_ho_wo_k_global_desc
.
GetLength
(
I3
);
const
auto
Hi
=
in_n_hi_wi_c_global_desc
.
GetLength
(
I1
);
const
auto
Wi
=
in_n_hi_wi_c_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_n_ho_wo_k_global_desc
.
GetLength
(
I1
);
const
auto
Wo
=
out_n_ho_wo_k_global_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_k_y_x_c_global_desc
.
GetLength
(
I1
);
const
auto
X
=
wei_k_y_x_c_global_desc
.
GetLength
(
I2
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
if
(
!
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
))
{
throw
std
::
runtime_error
(
"wrong! 1x1, stride 1, no padding"
);
}
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_gemmk_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
C
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
constexpr
auto
GemmM1
=
Number
<
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
>
{};
constexpr
auto
GemmN1
=
Number
<
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
>
{};
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_k_m_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_k_m_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr
auto
b_k_n_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
b_k_n_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr
auto
c_m0_m1_n0_n1_global_tensor_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// GEMM
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v1
<
BlockSize
,
Float
,
AccFloat
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
1
,
GemmCThreadTransferDstScalarPerVector_GemmM1
,
decltype
(
a_k_m_global_iterator_hacks
),
decltype
(
b_k_n_global_iterator_hacks
),
decltype
(
c_m0_m1_n0_n1_global_tensor_iterator_hacks
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
printf
(
"%s: BlockSize %d, GridSize %d
\n
"
,
__func__
,
BlockSize
,
GridSize
);
#if 1 // pass tensor descriptors by their reference
#if 1 // pass tensor descriptors by their reference
index_t
nrepeat
=
100
;
index_t
nrepeat
=
100
;
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
c31867bd
...
@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -39,7 +39,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
1
#if
0
// run-time variables
// run-time variables
const auto in_n_c_hi_wi_desc =
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
...
@@ -67,7 +67,110 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -67,7 +67,110 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
#endif
#endif
#if 0
#if 1
// cdata = 16, BlockSize = 64, 16x64x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
2
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
2
;
#elif 0
// cdata = 16, BlockSize = 64, 16x64x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
2
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
1
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
2
;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
32
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
// cdata = 64, BlockSize = 128, 32x256x8
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
@@ -75,14 +178,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -75,14 +178,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr index_t GemmMPerThread
= 4;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr index_t GemmNPerThread
= 4;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr index_t GemmKPerThread
= 1;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr index_t GemmMLevel0Cluster
= 2;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr index_t GemmNLevel0Cluster
= 2;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr index_t GemmMLevel1Cluster
= 2;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr index_t GemmNLevel1Cluster
= 16;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
...
@@ -259,7 +362,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -259,7 +362,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif
1
#elif
0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
c31867bd
...
@@ -48,7 +48,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
...
@@ -48,7 +48,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
constexpr
auto
Y
=
WeiDesc
::
GetLengths
()[
I2
];
constexpr
auto
Y
=
WeiDesc
::
GetLengths
()[
I2
];
constexpr
auto
X
=
WeiDesc
::
GetLengths
()[
I3
];
constexpr
auto
X
=
WeiDesc
::
GetLengths
()[
I3
];
#if
1
#if
0
// run-time variables
// run-time variables
constexpr auto in_n_hi_wi_c_desc =
constexpr auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C));
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C));
...
@@ -110,6 +110,204 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
...
@@ -110,6 +110,204 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
out_nhwk_device_buf
.
ToDevice
(
out_nhwk
.
mData
.
data
());
out_nhwk_device_buf
.
ToDevice
(
out_nhwk
.
mData
.
data
());
#if 0
#if 0
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 16;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
#elif
0
// cdata = 32, BlockSize = 64, 16x128x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
2
;
#elif 1
// cdata = 64, BlockSize = 64, 16x256x2
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
2
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
1
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
1
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x4
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
32
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
32
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
32
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
32
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
2
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
8
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -174,7 +372,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
...
@@ -174,7 +372,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
#elif
1
#elif
0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
#elif 1
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
...
...
driver/src/conv_driver.cpp
View file @
c31867bd
...
@@ -22,10 +22,80 @@ int main(int argc, char* argv[])
...
@@ -22,10 +22,80 @@ int main(int argc, char* argv[])
#if 0
#if 0
constexpr index_t N = 1;
constexpr index_t N = 1;
constexpr index_t C = 32;
constexpr index_t C = 4;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
540
;
constexpr
index_t
HI
=
540
;
constexpr
index_t
WI
=
960
;
constexpr
index_t
WI
=
960
;
constexpr index_t K = 32;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
270
;
constexpr
index_t
WI
=
480
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 1
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
2
;
constexpr
index_t
HI
=
1080
;
constexpr
index_t
WI
=
1920
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
540
;
constexpr
index_t
WI
=
960
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
270
;
constexpr
index_t
WI
=
480
;
constexpr
index_t
K
=
16
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -616,7 +686,7 @@ int main(int argc, char* argv[])
...
@@ -616,7 +686,7 @@ int main(int argc, char* argv[])
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif
1
#elif
0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment