Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8a906f5f
"library/vscode:/vscode.git/clone" did not exist on "ddad386b299391700136a5e9d2b5e8b47d49aea8"
Commit
8a906f5f
authored
Mar 04, 2021
by
root
Browse files
add chwn_cyxk_khwn
parent
129cc2e3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
1774 additions
and
1 deletion
+1774
-1
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp
...convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp
+1339
-0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp
...convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp
+421
-0
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+14
-1
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp
0 → 100644
View file @
8a906f5f
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_CHWN_CYXK_KHWN_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_CHWN_CYXK_KHWN_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = Y * X * C
template
<
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_pad
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_c_y_x_k_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_c_hi_wi_n_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_k_ho_wo_n_global_desc
,
const
ConvStrides
&
conv_strides
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
auto
N
=
in_c_hi_wi_n_global_desc
.
GetLength
(
I3
);
const
auto
C
=
in_c_hi_wi_n_global_desc
.
GetLength
(
I0
);
const
auto
K
=
out_k_ho_wo_n_global_desc
.
GetLength
(
I0
);
const
auto
Hi
=
in_c_hi_wi_n_global_desc
.
GetLength
(
I1
);
const
auto
Wi
=
in_c_hi_wi_n_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
out_k_ho_wo_n_global_desc
.
GetLength
(
I1
);
const
auto
Wo
=
out_k_ho_wo_n_global_desc
.
GetLength
(
I2
);
const
auto
Y
=
wei_c_y_x_k_global_desc
.
GetLength
(
I1
);
const
auto
X
=
wei_c_y_x_k_global_desc
.
GetLength
(
I2
);
const
auto
ConvStrideH
=
conv_strides
[
I0
];
const
auto
ConvStrideW
=
conv_strides
[
I1
];
const
auto
ConvDilationH
=
conv_dilations
[
I0
];
const
auto
ConvDilationW
=
conv_dilations
[
I1
];
const
auto
InLeftPadH
=
in_left_pads
[
I0
];
const
auto
InLeftPadW
=
in_left_pads
[
I1
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
C
*
Y
*
X
,
K
)),
make_tuple
(
make_pass_through_transform
(
C
*
Y
*
X
),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// input tensor
const
auto
in_c_hip_wip_n_global_desc
=
transform_dynamic_tensor_descriptor
(
in_c_hi_wi_n_global_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_c_y_ho_x_wo_n_global_desc
=
transform_dynamic_tensor_descriptor
(
in_c_hip_wip_n_global_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
in_c_y_ho_x_wo_n_global_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
C
,
Y
,
X
)),
make_merge_transform
(
make_tuple
(
Ho
,
Wo
,
N
))),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
Ho
*
Wo
*
N
)),
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
Ho
*
Wo
*
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
constexpr
auto
GemmM1
=
Number
<
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
>
{};
constexpr
auto
GemmN1
=
Number
<
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
>
{};
const
auto
GemmM0
=
GemmM
/
GemmM1
;
const
auto
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_k_m_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
constexpr
auto
a_k_m_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr
auto
b_k_n_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
constexpr
auto
b_k_n_global_move_slice_window_iterator_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr
auto
c_m0_m1_n0_n1_global_tensor_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
// GEMM
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v1
<
BlockSize
,
Float
,
AccFloat
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
1
,
GemmCThreadTransferDstScalarPerVector_GemmM1
,
decltype
(
a_k_m_global_iterator_hacks
),
decltype
(
b_k_n_global_iterator_hacks
),
decltype
(
c_m0_m1_n0_n1_global_tensor_iterator_hacks
),
decltype
(
a_k_m_global_move_slice_window_iterator_hack
),
decltype
(
b_k_n_global_move_slice_window_iterator_hack
)
>
;
const
auto
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
printf
(
"%s: BlockSize %d, GridSize %d
\n
"
,
__func__
,
BlockSize
,
GridSize
);
#if 1 // pass tensor descriptors by their reference
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif 1 // pass tensor descriptors by their pointers
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
)
*
,
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
)
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
)
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
reinterpret_cast
<
const
ADesc
*>
(
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
()),
p_wei_global
,
reinterpret_cast
<
const
BDesc
*>
(
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
()),
p_in_global
,
reinterpret_cast
<
const
CDesc
*>
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
()),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#elif 1 // pass tensor descriptor by void*
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#endif
}
};
#if 0
template <index_t BlockSize,
typename Float,
typename AccFloat,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_1x1
{
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_c_y_x_k_global_desc,
const DynamicTensorDescriptor<In...>& in_c_hi_wi_n_global_desc,
const DynamicTensorDescriptor<Out...>& out_k_ho_wo_n_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_c_hi_wi_n_global_desc.GetLength(I3);
const auto C = in_c_hi_wi_n_global_desc.GetLength(I0);
const auto K = out_k_ho_wo_n_global_desc.GetLength(I0);
const auto Hi = in_c_hi_wi_n_global_desc.GetLength(I1);
const auto Wi = in_c_hi_wi_n_global_desc.GetLength(I2);
const auto Ho = out_k_ho_wo_n_global_desc.GetLength(I1);
const auto Wo = out_k_ho_wo_n_global_desc.GetLength(I2);
const auto Y = wei_c_y_x_k_global_desc.GetLength(I1);
const auto X = wei_c_y_x_k_global_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0))
{
throw std::runtime_error("wrong! 1x1, stride 1, no padding");
}
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto GemmM1 = Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{};
constexpr auto GemmN1 = Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
const auto GemmM0 = GemmM / GemmM1;
const auto GemmN0 = GemmN / GemmN1;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over b_k_n_global tensor
constexpr auto b_k_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize);
#if 1 // pass tensor descriptors by their reference
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif 1 // pass tensor descriptors by their pointers
using ADesc = decltype(wei_gemmk_gemmm_global_desc);
using BDesc = decltype(in_gemmk_gemmn_global_desc);
using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
&out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc)*,
const Float*,
decltype(in_gemmk_gemmn_global_desc)*,
const Float*,
decltype(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
reinterpret_cast<const ADesc*>(
wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
p_wei_global,
reinterpret_cast<const BDesc*>(
in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
p_in_global,
reinterpret_cast<const CDesc*>(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.GetDeviceBuffer()),
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#elif
1
// pass tensor descriptor by void*
using
ADesc
=
decltype
(
wei_gemmk_gemmm_global_desc
);
using
BDesc
=
decltype
(
in_gemmk_gemmn_global_desc
);
using
CDesc
=
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
DeviceMem
wei_gemmk_gemmm_global_desc_device_buf
(
sizeof
(
ADesc
));
DeviceMem
in_gemmk_gemmn_global_desc_device_buf
(
sizeof
(
BDesc
));
DeviceMem
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
(
sizeof
(
CDesc
));
wei_gemmk_gemmm_global_desc_device_buf
.
ToDevice
(
&
wei_gemmk_gemmm_global_desc
);
in_gemmk_gemmn_global_desc_device_buf
.
ToDevice
(
&
in_gemmk_gemmn_global_desc
);
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
ToDevice
(
&
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
);
index_t
nrepeat
=
100
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
void
*
,
const
Float
*
,
const
void
*
,
const
Float
*
,
const
void
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc_device_buf
.
GetDeviceBuffer
(),
p_wei_global
,
in_gemmk_gemmn_global_desc_device_buf
.
GetDeviceBuffer
(),
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
.
GetDeviceBuffer
(),
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)(
std
::
size_t
(
2
)
*
N
*
K
*
Ho
*
Wo
*
C
*
Y
*
X
)
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
#endif
}
}
;
#endif
}
// namespace ck
#endif
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp
0 → 100644
View file @
8a906f5f
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
ConvStrides
,
class
ConvDilations
,
class
InLeftPads
,
class
InRightPads
>
void
device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
ck
::
index_t
nrepeat
)
{
std
::
cout
<<
"device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn"
<<
std
::
endl
;
using
namespace
ck
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
N
=
OutDesc
::
GetLengths
()[
I0
];
constexpr
auto
K
=
OutDesc
::
GetLengths
()[
I1
];
constexpr
auto
C
=
WeiDesc
::
GetLengths
()[
I1
];
constexpr
auto
Hi
=
InDesc
::
GetLengths
()[
I2
];
constexpr
auto
Wi
=
InDesc
::
GetLengths
()[
I3
];
constexpr
auto
Ho
=
OutDesc
::
GetLengths
()[
I2
];
constexpr
auto
Wo
=
OutDesc
::
GetLengths
()[
I3
];
constexpr
auto
Y
=
WeiDesc
::
GetLengths
()[
I2
];
constexpr
auto
X
=
WeiDesc
::
GetLengths
()[
I3
];
#if 0
// run-time variables
constexpr auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C));
constexpr auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C));
constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));
const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_multi_index(InRightPads{});
#else
// compile-time variables
constexpr
auto
in_c_hi_wi_n_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
C
,
Hi
,
Wi
,
N
));
constexpr
auto
wei_c_y_x_k_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
C
,
Y
,
X
,
K
));
constexpr
auto
out_k_ho_wo_n_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
Ho
,
Wo
,
N
));
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
in_left_pads
=
sequence_to_tuple_of_number
(
InLeftPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
#endif
Tensor
<
float
>
in_chwn
(
make_HostTensorDescriptor
(
make_native_tensor_descriptor_packed
(
Sequence
<
C
,
Hi
,
Wi
,
N
>
{})));
Tensor
<
float
>
wei_cyxk
(
make_HostTensorDescriptor
(
make_native_tensor_descriptor_packed
(
Sequence
<
C
,
Y
,
X
,
K
>
{})));
Tensor
<
float
>
out_khwn
(
make_HostTensorDescriptor
(
make_native_tensor_descriptor_packed
(
Sequence
<
K
,
Ho
,
Wo
,
N
>
{})));
auto
f_nchw2chwn
=
[
&
](
auto
c
,
auto
hi
,
auto
wi
,
auto
n
)
{
in_chwn
(
c
,
hi
,
wi
,
n
)
=
in_nchw
(
n
,
c
,
hi
,
wi
);
};
auto
f_kcyx2cyxk
=
[
&
](
auto
c
,
auto
y
,
auto
x
,
auto
k
)
{
wei_cyxk
(
c
,
y
,
x
,
k
)
=
wei_kcyx
(
k
,
c
,
y
,
x
);
};
auto
f_nkhw2khwn
=
[
&
](
auto
k
,
auto
ho
,
auto
wo
,
auto
n
)
{
out_khwn
(
k
,
ho
,
wo
,
n
)
=
out_nkhw
(
n
,
k
,
ho
,
wo
);
};
make_ParallelTensorFunctor
(
f_nchw2chwn
,
C
,
Hi
,
Wi
,
N
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_kcyx2cyxk
,
C
,
Y
,
X
,
K
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_nkhw2khwn
,
K
,
Ho
,
Wo
,
N
)(
std
::
thread
::
hardware_concurrency
());
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_chwn_device_buf
(
data_sz
*
in_chwn
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_cyxk_device_buf
(
data_sz
*
wei_cyxk
.
mDesc
.
GetElementSpace
());
DeviceMem
out_khwn_device_buf
(
data_sz
*
out_khwn
.
mDesc
.
GetElementSpace
());
in_chwn_device_buf
.
ToDevice
(
in_chwn
.
mData
.
data
());
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
#if 1
// cdata = 16, BlockSize = 64, 16x64x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
2
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
1
;
#elif 1
// cdata = 32, BlockSize = 64, 16x128x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
2
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
2
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
2
;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x2
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
2
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
1
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
2
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x4
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
16
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
1
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
16
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x4
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
32
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
1
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
32
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
GemmMPerBlock
=
32
;
constexpr
index_t
GemmNPerBlock
=
256
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
16
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
32
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
2
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
8
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 0
// cdata = 64, BlockSize = 256, 128x128x8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
8
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
8
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
2
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
8
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmM1
=
4
;
#endif
constexpr
auto
conv_driver
=
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_pad
#elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_no_pad
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_1x1
#endif
<
BlockSize
,
TDevice
,
TDevice
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferSrcScalarPerVector_GemmK
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmCThreadTransferDstScalarPerVector_GemmM1
>
{};
conv_driver
.
Run
(
wei_c_y_x_k_desc
,
in_c_hi_wi_n_desc
,
out_k_ho_wo_n_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
static_cast
<
TDevice
*>
(
wei_cyxk_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
in_chwn_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
out_khwn_device_buf
.
GetDeviceBuffer
()));
out_khwn_device_buf
.
FromDevice
(
out_khwn
.
mData
.
data
());
auto
f_khwn2nkhw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
out_nkhw
(
n
,
k
,
ho
,
wo
)
=
out_khwn
(
k
,
ho
,
wo
,
n
);
};
make_ParallelTensorFunctor
(
f_khwn2nkhw
,
N
,
K
,
Ho
,
Wo
)(
std
::
thread
::
hardware_concurrency
());
}
driver/src/conv_driver.cpp
View file @
8a906f5f
...
...
@@ -16,6 +16,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -701,7 +702,7 @@ int main(int argc, char* argv[])
LeftPads
{},
RightPads
{},
nrepeat
);
#elif
1
#elif
0
device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
...
...
@@ -713,6 +714,18 @@ int main(int argc, char* argv[])
LeftPads
{},
RightPads
{},
nrepeat
);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
#elif 0
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment