gaoqiong / composable_kernel · Commits

Commit b3a012bc, authored Mar 12, 2021 by root
Parent: 2662f8e5

    thread mapping

Showing 7 changed files with 232 additions and 787 deletions (+232 -787)
+54  -379  composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+161 -125  composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+1   -5    composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+3   -3    composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
+1   -1    composable_kernel/include/utility/config.amd.hpp.in
+6   -268  driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+6   -6    driver/src/conv_driver.cpp
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @ b3a012bc

@@ -111,24 +111,28 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
 
-        const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
+        const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
             in_n_c_y_ho_x_wo_global_desc,
             make_tuple(make_merge_transform(make_tuple(C, Y, X)),
-                       make_merge_transform(make_tuple(N, Ho, Wo))),
-            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+                       make_pass_through_transform(N),
+                       make_pass_through_transform(Ho),
+                       make_pass_through_transform(Wo)),
+            make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
         // output tensor
-        const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
+        const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)),
             make_tuple(make_pass_through_transform(K),
-                       make_merge_transform(make_tuple(N, Ho * Wo))),
-            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+                       make_pass_through_transform(N),
+                       make_pass_through_transform(Ho),
+                       make_pass_through_transform(Wo)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
 
-        const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
-        const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
-        const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
+        const auto GemmM = K;
+        const auto GemmN = N * Ho * Wo;
+        const auto GemmK = C * Y * X;
 
         if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
              GemmK % GemmKPerBlock == 0))
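The replacement GemmM/GemmN/GemmK definitions spell out the standard forward-convolution implicit-GEMM mapping: output channels become GEMM rows, output pixels become GEMM columns, and the reduction runs over input channels times filter taps. A minimal host-side sketch of the same arithmetic and the driver's divisibility check, with illustrative sizes that are not taken from this commit:

// Standalone check of the implicit-GEMM dimension mapping used above.
// All sizes here are illustrative placeholders.
#include <cassert>
#include <cstdio>

int main()
{
    const int N = 1, C = 4, K = 16;         // batch, input channels, output channels
    const int Ho = 8, Wo = 8, Y = 3, X = 3; // output size, filter size

    // C[GemmM, GemmN] = A[GemmK, GemmM]^T * B[GemmK, GemmN]
    const int GemmM = K;           // one GEMM row per output channel
    const int GemmN = N * Ho * Wo; // one GEMM column per output pixel
    const int GemmK = C * Y * X;   // reduction over channels and filter taps

    const int GemmMPerBlock = 16, GemmNPerBlock = 64, GemmKPerBlock = 4;

    // same divisibility requirement the driver enforces before launching
    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
           GemmK % GemmKPerBlock == 0);

    std::printf("GemmM=%d GemmN=%d GemmK=%d\n", GemmM, GemmN, GemmK); // 16 64 36
    return 0;
}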
@@ -136,20 +140,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         {
             throw std::runtime_error("wrong! GEMM size no divisible");
         }
 
-        constexpr auto GemmM1 =
-            Number<GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster>{};
-        constexpr auto GemmN1 =
-            Number<GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster>{};
-
-        const auto GemmM0 = GemmM / GemmM1;
-        const auto GemmN0 = GemmN / GemmN1;
-
-        const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
-            transform_dynamic_tensor_descriptor(
-                out_gemmm_gemmn_global_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                           make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
         // hack to control index calculation when iterating over a_k_m_global tensor
         constexpr auto a_k_m_global_iterator_hacks =
             make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
@@ -157,28 +147,32 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
 
         // hack to control index calculation when iterating over b_k_n_global tensor
         constexpr auto b_k_n_global_iterator_hacks =
-            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
+                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
 
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
-            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
 
-        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
+        // hack for NKHW format
+        constexpr auto c_k_n_h_w_global_tensor_iterator_hacks =
             make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                   Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 1, 0, 0>{},
-                                  Sequence<0, 0, 1, 0, 0>{}),
+                                  Sequence<0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0>{}),
                        make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                   Sequence<0, 0, 0, 0, 0>{},
-                                  Sequence<0, 0, 2, 0, 0>{},
-                                  Sequence<0, 0, 2, 0, 0>{}));
+                                  Sequence<0, 0, 0, 0, 0>{},
+                                  Sequence<0, 0, 0, 0, 0>{}));
 
 #if 1
         // GEMM
         using gridwise_gemm =
             GridwiseDynamicGemm_km_kn_mn_v2<BlockSize,
@@ -186,8 +180,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                             AccFloat,
                                             InMemoryDataOperation::Set,
                                             decltype(wei_gemmk_gemmm_global_desc),
-                                            decltype(in_gemmk_gemmn_global_desc),
-                                            decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                            decltype(in_gemmk_n_ho_wo_global_desc),
+                                            decltype(out_gemmm_n_ho_wo_global_desc),
                                             GemmMPerBlock,
                                             GemmNPerBlock,
                                             GemmKPerBlock,
@@ -208,19 +202,19 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                             false, // don't move back src coordinate after threadwise copy
                                             GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
                                             GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-                                            Sequence<0, 1>,
-                                            Sequence<0, 1>,
-                                            1,
+                                            Sequence<3, 2, 1, 0>,
+                                            Sequence<3, 2, 1, 0>,
+                                            3,
                                             GemmBBlockTransferSrcScalarPerVector_GemmN,
                                             GemmBBlockTransferDstScalarPerVector_GemmN,
                                             false, // don't move back src coordinate after threadwise copy, which will be fused with
                                                    // MoveSrcSliceWindow() to save addr computation
-                                            Sequence<2, 3, 0, 1>,
+                                            Sequence<3, 2, 1, 0>,
                                             3,
                                             GemmCThreadTransferDstScalarPerVector_GemmN1,
                                             decltype(a_k_m_global_iterator_hacks),
                                             decltype(b_k_n_global_iterator_hacks),
-                                            decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks),
+                                            decltype(c_k_n_h_w_global_tensor_iterator_hacks),
                                             decltype(a_k_m_global_move_slice_window_iterator_hack),
                                             decltype(b_k_n_global_move_slice_window_iterator_hack)>;
@@ -230,7 +224,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
         const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
 
-#if 1 // pass tensor descriptors by their reference
         index_t nrepeat = 100;
 
         for(index_t i = 0; i < 5; ++i)
@@ -248,10 +241,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, true>,
                                            integral_constant<bool, true>>;
@@ -263,9 +255,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                   0,
                                   wei_gemmk_gemmm_global_desc,
                                   p_wei_global,
-                                  in_gemmk_gemmn_global_desc,
+                                  in_gemmk_n_ho_wo_global_desc,
                                   p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                                  out_gemmm_n_ho_wo_global_desc,
                                   p_out_global,
                                   integral_constant<bool, true>{},
                                   integral_constant<bool, true>{});
@@ -276,10 +268,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, true>,
                                            integral_constant<bool, false>>;
@@ -291,9 +282,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                   0,
                                   wei_gemmk_gemmm_global_desc,
                                   p_wei_global,
-                                  in_gemmk_gemmn_global_desc,
+                                  in_gemmk_n_ho_wo_global_desc,
                                   p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                                  out_gemmm_n_ho_wo_global_desc,
                                   p_out_global,
                                   integral_constant<bool, true>{},
                                   integral_constant<bool, false>{});
@@ -304,10 +295,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, false>,
                                            integral_constant<bool, true>>;
@@ -319,9 +309,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                   0,
                                   wei_gemmk_gemmm_global_desc,
                                   p_wei_global,
-                                  in_gemmk_gemmn_global_desc,
+                                  in_gemmk_n_ho_wo_global_desc,
                                   p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                                  out_gemmm_n_ho_wo_global_desc,
                                   p_out_global,
                                   integral_constant<bool, false>{},
                                   integral_constant<bool, true>{});
@@ -332,10 +322,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                     run_gridwise_operation<gridwise_gemm,
                                            decltype(wei_gemmk_gemmm_global_desc),
                                            const Float*,
-                                           decltype(in_gemmk_gemmn_global_desc),
+                                           decltype(in_gemmk_n_ho_wo_global_desc),
                                            const Float*,
-                                           decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
+                                           decltype(out_gemmm_n_ho_wo_global_desc),
                                            Float*,
                                            integral_constant<bool, false>,
                                            integral_constant<bool, false>>;
@@ -347,323 +336,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
                                   0,
                                   wei_gemmk_gemmm_global_desc,
                                   p_wei_global,
-                                  in_gemmk_gemmn_global_desc,
+                                  in_gemmk_n_ho_wo_global_desc,
                                   p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                                  out_gemmm_n_ho_wo_global_desc,
                                   p_out_global,
                                   integral_constant<bool, false>{},
                                   integral_constant<bool, false>{});
                 }
             }
 
             timer.End();
 
             float ave_time = timer.GetElapsedTime() / nrepeat;
 
             float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
                                                             wei_k_c_y_x_global_desc,
                                                             out_n_k_ho_wo_global_desc) /
                          (std::size_t(1000) * 1000 * 1000) / ave_time;
 
             std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                       << std::endl;
         }
-#elif 1 // pass tensor descriptors by their pointers
-        using ADesc = decltype(wei_gemmk_gemmm_global_desc);
-        using BDesc = decltype(in_gemmk_gemmn_global_desc);
-        using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-        DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
-        DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
-        DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
-
-        wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
-        in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
-        out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
-            &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-        index_t nrepeat = 100;
-
-        for(index_t i = 0; i < 5; ++i)
-        {
-            std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-            KernelTimer timer;
-            timer.Start();
-
-            for(index_t j = 0; j < nrepeat; ++j)
-            {
-                if(has_main_k_block_loop && has_double_tail_k_block_loop)
-                {
-                    const auto kernel = run_gridwise_operation<
-                        gridwise_gemm,
-                        decltype(wei_gemmk_gemmm_global_desc)*,
-                        const Float*,
-                        decltype(in_gemmk_gemmn_global_desc)*,
-                        const Float*,
-                        decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                        Float*,
-                        integral_constant<bool, true>,
-                        integral_constant<bool, true>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  reinterpret_cast<const ADesc*>(
-                                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_wei_global,
-                                  reinterpret_cast<const BDesc*>(
-                                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_in_global,
-                                  reinterpret_cast<const CDesc*>(
-                                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                          .GetDeviceBuffer()),
-                                  p_out_global,
-                                  integral_constant<bool, true>{},
-                                  integral_constant<bool, true>{});
-                }
-                else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-                {
-                    const auto kernel = run_gridwise_operation<
-                        gridwise_gemm,
-                        decltype(wei_gemmk_gemmm_global_desc)*,
-                        const Float*,
-                        decltype(in_gemmk_gemmn_global_desc)*,
-                        const Float*,
-                        decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                        Float*,
-                        integral_constant<bool, true>,
-                        integral_constant<bool, false>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  reinterpret_cast<const ADesc*>(
-                                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_wei_global,
-                                  reinterpret_cast<const BDesc*>(
-                                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_in_global,
-                                  reinterpret_cast<const CDesc*>(
-                                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                          .GetDeviceBuffer()),
-                                  p_out_global,
-                                  integral_constant<bool, true>{},
-                                  integral_constant<bool, false>{});
-                }
-                else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-                {
-                    const auto kernel = run_gridwise_operation<
-                        gridwise_gemm,
-                        decltype(wei_gemmk_gemmm_global_desc)*,
-                        const Float*,
-                        decltype(in_gemmk_gemmn_global_desc)*,
-                        const Float*,
-                        decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                        Float*,
-                        integral_constant<bool, false>,
-                        integral_constant<bool, true>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  reinterpret_cast<const ADesc*>(
-                                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_wei_global,
-                                  reinterpret_cast<const BDesc*>(
-                                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_in_global,
-                                  reinterpret_cast<const CDesc*>(
-                                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                          .GetDeviceBuffer()),
-                                  p_out_global,
-                                  integral_constant<bool, false>{},
-                                  integral_constant<bool, true>{});
-                }
-                else
-                {
-                    const auto kernel = run_gridwise_operation<
-                        gridwise_gemm,
-                        decltype(wei_gemmk_gemmm_global_desc)*,
-                        const Float*,
-                        decltype(in_gemmk_gemmn_global_desc)*,
-                        const Float*,
-                        decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                        Float*,
-                        integral_constant<bool, false>,
-                        integral_constant<bool, false>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  reinterpret_cast<const ADesc*>(
-                                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_wei_global,
-                                  reinterpret_cast<const BDesc*>(
-                                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
-                                  p_in_global,
-                                  reinterpret_cast<const CDesc*>(
-                                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                          .GetDeviceBuffer()),
-                                  p_out_global,
-                                  integral_constant<bool, false>{},
-                                  integral_constant<bool, false>{});
-                }
-            }
-
-            timer.End();
-
-            float ave_time = timer.GetElapsedTime() / nrepeat;
-
-            float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
-                                                            wei_k_c_y_x_global_desc,
-                                                            out_n_k_ho_wo_global_desc) /
-                         (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
-                      << std::endl;
-        }
-#elif 1 // pass tensor descriptor by void*
-        using ADesc = decltype(wei_gemmk_gemmm_global_desc);
-        using BDesc = decltype(in_gemmk_gemmn_global_desc);
-        using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-        DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc));
-        DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc));
-        DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc));
-
-        wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc);
-        in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc);
-        out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice(
-            &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc);
-
-        index_t nrepeat = 100;
-
-        for(index_t i = 0; i < 5; ++i)
-        {
-            std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-            KernelTimer timer;
-            timer.Start();
-
-            for(index_t j = 0; j < nrepeat; ++j)
-            {
-                if(has_main_k_block_loop && has_double_tail_k_block_loop)
-                {
-                    const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               Float*,
-                                                               integral_constant<bool, true>,
-                                                               integral_constant<bool, true>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_wei_global,
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer(),
-                                  p_out_global,
-                                  integral_constant<bool, true>{},
-                                  integral_constant<bool, true>{});
-                }
-                else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-                {
-                    const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               Float*,
-                                                               integral_constant<bool, true>,
-                                                               integral_constant<bool, false>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_wei_global,
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer(),
-                                  p_out_global,
-                                  integral_constant<bool, true>{},
-                                  integral_constant<bool, false>{});
-                }
-                else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-                {
-                    const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               Float*,
-                                                               integral_constant<bool, false>,
-                                                               integral_constant<bool, true>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_wei_global,
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer(),
-                                  p_out_global,
-                                  integral_constant<bool, false>{},
-                                  integral_constant<bool, true>{});
-                }
-                else
-                {
-                    const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               const Float*,
-                                                               const void*,
-                                                               Float*,
-                                                               integral_constant<bool, false>,
-                                                               integral_constant<bool, false>>;
-
-                    launch_kernel(kernel,
-                                  dim3(GridSize),
-                                  dim3(BlockSize),
-                                  0,
-                                  0,
-                                  wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_wei_global,
-                                  in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                                  p_in_global,
-                                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                                      .GetDeviceBuffer(),
-                                  p_out_global,
-                                  integral_constant<bool, false>{},
-                                  integral_constant<bool, false>{});
 ...
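The two deleted #elif branches exercised alternative launch paths that stage each tensor descriptor in device memory and hand the kernel a typed pointer or a void* instead of the descriptor by value. A condensed HIP sketch of that pattern, with a hypothetical POD Desc type standing in for the real descriptor:

#include <hip/hip_runtime.h>

// "Desc" is a hypothetical stand-in for the real (trivially copyable)
// descriptor types; the real code derives them via decltype.
struct Desc
{
    int lengths[4];
    int strides[4];
};

__global__ void kernel(const void* p_desc)
{
    const Desc& desc = *reinterpret_cast<const Desc*>(p_desc);
    (void)desc; // real code would compute tensor offsets through the descriptor
}

int main()
{
    Desc host_desc = {{1, 16, 8, 8}, {1024, 64, 8, 1}};

    void* dev_desc = nullptr;
    hipMalloc(&dev_desc, sizeof(Desc));                                   // DeviceMem(sizeof(Desc))
    hipMemcpy(dev_desc, &host_desc, sizeof(Desc), hipMemcpyHostToDevice); // .ToDevice(&desc)

    hipLaunchKernelGGL(kernel, dim3(1), dim3(64), 0, 0, (const void*)dev_desc);
    hipDeviceSynchronize();
    hipFree(dev_desc);
    return 0;
}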
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp @ b3a012bc

@@ -68,15 +68,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
 
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
 
         return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
     }
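With the B tile now hard-coded to KPerBlock x 1 x 8 x 8, the LDS budget can be checked by hand. A sketch assuming KPerBlock = 4, MPerBlock = 16 and float data, ignoring the alignment padding that integer_least_multiple may add:

#include <cstdio>

// Back-of-envelope LDS estimate for the double-buffered A/B tiles above.
int main()
{
    const int KPerBlock = 4, MPerBlock = 16;

    const int a_elems = KPerBlock * MPerBlock; // a_k_m_block_desc:       4 x 16
    const int b_elems = KPerBlock * 1 * 8 * 8; // b_cyx_n_h_w_block_desc: 4 x 1 x 8 x 8

    // factor 2: LDS double buffering (even/odd halves)
    const int lds_bytes = 2 * (a_elems + b_elems) * (int)sizeof(float);

    std::printf("~%d bytes of LDS per workgroup\n", lds_bytes); // ~2560
    return 0;
}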
@@ -84,9 +84,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
+                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                        const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         Float* __restrict__ p_shared_block,
                         integral_constant<bool, HasMainKBlockLoop>,
@@ -97,7 +97,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
-        const auto N = b_k_n_global_desc.GetLength(I1);
+        const auto N = b_cyx_n_h_w_global_desc.GetLength(I1);
 
         // divide block work by [M, N]
 #if 0
@@ -118,7 +118,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
 #endif
 
         const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
         const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
+        const index_t h_block_data_on_global = n_block_work_id * 8;
+        const index_t w_block_data_on_global = n_block_work_id * 8;
 
         // lds max alignment
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
@@ -133,8 +135,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // B matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+        constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align);
 
         // A matrix blockwise copy
         auto a_blockwise_copy =
@@ -166,33 +168,52 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 make_multi_index(0, 0));
 
         // B matrix blockwise copy
-        auto b_blockwise_copy =
-            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
-                                                   InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, NPerBlock>,
-                                                   BBlockTransferThreadSliceLengths_K_N,
-                                                   BBlockTransferThreadClusterLengths_K_N,
-                                                   BBlockTransferThreadClusterArrangeOrder,
-                                                   Float,
-                                                   Float,
-                                                   decltype(b_k_n_global_desc),
-                                                   decltype(b_k_n_block_desc),
-                                                   BBlockTransferSrcAccessOrder,
-                                                   Sequence<0, 1>,
-                                                   BBlockTransferSrcVectorDim,
-                                                   1,
-                                                   BBlockTransferSrcScalarPerVector,
-                                                   BBlockTransferDstScalarPerVector_N,
-                                                   AddressSpace::Global,
-                                                   AddressSpace::Lds,
-                                                   1,
-                                                   1,
-                                                   BThreadTransferSrcResetCoordinateAfterRun,
-                                                   true>(
-                b_k_n_global_desc,
-                make_multi_index(0, n_block_data_on_global),
-                b_k_n_block_desc,
-                make_multi_index(0, 0));
+        auto b_blockwise_copy =
+            BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
+                                                   InMemoryDataOperation::Set,
+                                                   Sequence<KPerBlock, 1, 8, 8>, // BlockSliceLengths
+                                                   Sequence<KPerBlock, 1, 1, 1>, // ThreadSliceLengths_K_N
+                                                   Sequence<1, 1, 8, 8>,         // ThreadClusterLengths_K_N
+                                                   Sequence<3, 2, 0, 1>,         // ThreadClusterArrangeOrder
+                                                   Float,
+                                                   Float,
+                                                   decltype(b_cyx_n_h_w_global_desc), // SrcDesc
+                                                   decltype(b_cyx_n_h_w_block_desc),  // DstDesc
+                                                   Sequence<3, 2, 0, 1>,         // SrcDimAccessOrder
+                                                   Sequence<3, 2, 0, 1>,         // DstDimAccessOrder
+                                                   3,                            // SrcVectorDim
+                                                   3,                            // DstVectorDim
+                                                   1,                            // SrcScalarPerVector
+                                                   1,                            // DstScalarPerVector
+                                                   AddressSpace::Global,
+                                                   AddressSpace::Lds,
+                                                   1,
+                                                   1,
+                                                   BThreadTransferSrcResetCoordinateAfterRun,
+                                                   true>(
+                b_cyx_n_h_w_global_desc,
+                make_multi_index(0, 0, h_block_data_on_global, w_block_data_on_global),
+                b_cyx_n_h_w_block_desc,
+                make_multi_index(0, 0, 0, 0));
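The rewritten transfer spreads the 64 threads of the block as a 1x1x8x8 cluster over the KPerBlock x 1 x 8 x 8 slice, so each thread copies one KPerBlock-deep column at a fixed (h, w). A sketch of that decomposition (a hypothetical standalone illustration, not library code):

#include <cstdio>

// How a 1x1x8x8 thread cluster tiles the KPerBlock x 1 x 8 x 8 block slice:
// thread (h, w) owns the column {k in [0, KPerBlock), n = 0, h, w}.
int main()
{
    const int BlockSize = 64, KPerBlock = 4;

    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const int h = tid / 8; // cluster coordinate along H
        const int w = tid % 8; // cluster coordinate along W

        if(tid < 2 || tid == BlockSize - 1) // print a few examples
            std::printf("thread %2d copies k = 0..%d at (n, h, w) = (0, %d, %d)\n",
                        tid, KPerBlock - 1, h, w);
    }
    return 0;
}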
+#if 0
+        constexpr auto b_cyx_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThread>{}, Number<NPerThread>{}));
+
+        using BThreadwiseTransfer =
+            ThreadwiseDynamicTensorSliceTransfer_v2<Float,
+                                                    Float,
+                                                    decltype(b_cyx_n_h_w_global_desc),
+                                                    decltype(b_cyx_n_h_w_thread_desc),
+                                                    Sequence<KPerThread, NPerThread>,
+                                                    BBlockTransferSrcAccessOrder,
+                                                    BBlockTransferSrcVectorDim,
+                                                    BBlockTransferSrcScalarPerVector,
+                                                    AddressSpace::Global,
+                                                    AddressSpace::Vgpr,
+                                                    InMemoryDataOperation::Set,
+                                                    1,
+                                                    true>;
+#endif
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
@@ -201,23 +222,24 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
         // sanity check
-        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
-                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
-                      "wrong!");
+        // static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
+        //                   NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
+        //               "wrong!");
 
-        constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-        constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
+        // constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
+        // constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
 
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
+        constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
 
+#if 0
         const auto blockwise_gemm =
             BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
                                             decltype(a_k_m_block_desc),
-                                            decltype(b_k_n_block_desc),
-                                            decltype(c_m0m1_n0n1_thread_desc),
+                                            decltype(b_cyx_n_h_w_block_desc),
+                                            decltype(c_k_n_h_w_thread_desc),
                                             MPerThread,
                                             NPerThread,
                                             KPerThread,
@@ -225,46 +247,52 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                                             NLevel0Cluster,
                                             MLevel1Cluster,
                                             NLevel1Cluster,
-                                            MPerThread,
-                                            NPerThread>{};
+                                            1,
+                                            1>{};
+#endif
 
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
 
         constexpr auto b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
 
         Float* p_a_block_double = p_shared_block;
         Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
 
         // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
+        AccFloat p_c_thread[c_k_n_h_w_thread_desc.GetElementSpaceSize()];
 
+        for(index_t i = 0; i < c_k_n_h_w_thread_desc.GetElementSpaceSize(); i++)
+        {
+            p_c_thread[i] = 0;
+        }
 
         // zero out threadwise output
-        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
+        // threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread);
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
 
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
 
         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
         constexpr auto a_k_m_global_move_slice_window_iterator_hack =
             AGlobalMoveSliceWindowIteratorHacks{};
-        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
+        constexpr auto b_cyx_n_h_w_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
 
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double);
         }
 
         if constexpr(HasMainKBlockLoop)
@@ -285,9 +313,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
-                                                    b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                // b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
+                //                                     b_block_slice_copy_step,
+                //                                     b_cyx_n_h_w_global_move_slice_window_iterator_hack);
+
+                b_blockwise_copy.MoveSrcSliceWindow(
+                    b_cyx_n_h_w_global_desc,
+                    b_block_slice_copy_step,
+                    b_cyx_n_h_w_global_move_slice_window_iterator_hack);
 
                 __syncthreads();
@@ -295,22 +328,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-                b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+                // blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_odd);
 
                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
                                                     b_block_slice_copy_step,
-                                                    b_k_n_global_move_slice_window_iterator_hack);
+                                                    b_cyx_n_h_w_global_move_slice_window_iterator_hack);
 
                 __syncthreads();
@@ -318,14 +351,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
                 a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-                b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+                // blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_even);
 
                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
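The even/odd structure around these hunks is standard LDS double buffering: while the GEMM consumes one buffer, the blockwise copies prefetch the next K-slab into the other (the GEMM calls themselves are commented out at this stage of the commit). A host-side schematic of the schedule, where read/compute/write stand in for RunRead, blockwise_gemm.Run and RunWrite; this is not cycle-accurate:

#include <cstdio>

int main()
{
    const int K = 32, KPerBlock = 4;
    const int num_slabs = K / KPerBlock;

    std::printf("preload: read slab 0, write LDS buffer 0\n");

    for(int s = 1; s < num_slabs; ++s)
    {
        // prefetch the next slab while the GEMM consumes the previous one
        std::printf("read slab %d | GEMM on buffer %d | write buffer %d\n",
                    s, (s - 1) % 2, s % 2);
    }

    std::printf("final GEMM on buffer %d\n", (num_slabs - 1) % 2);
    return 0;
}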
@@ -337,53 +370,49 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
                                                 a_block_slice_copy_step,
                                                 a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+            b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
                                                 b_block_slice_copy_step,
-                                                b_k_n_global_move_slice_window_iterator_hack);
+                                                b_cyx_n_h_w_global_move_slice_window_iterator_hack);
 
             __syncthreads();
 
             // LDS double buffer: load last data from device mem
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
 
             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            // blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
 
             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double + b_block_space_size);
 
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
+            // blockwise_gemm.Run(p_a_block_double + a_block_space_size,
+            //                    p_b_block_double + b_block_space_size,
+            //                    p_c_thread);
         }
         else // if has 1 iteration left
         {
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            // blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
         }
 
+#if 1
         // output: register to global memory
         {
-            constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
-            constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
-
             // define input tensor descriptor for threadwise copy
             //   thread input tensor, src of threadwise copy
-            constexpr auto c_m0_m1_n0_n1_thread_desc =
-                make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                    Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+            constexpr auto c_k_n_h_w_thread_desc =
+                make_dynamic_naive_tensor_descriptor_packed_v2(
+                    make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
 
             // calculate origin of thread input tensor on global memory
             //   blockwise GEMM c matrix starting index
+#if 0
             const auto c_thread_mtx_on_block =
                 blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
@@ -392,47 +421,54 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
             const index_t n_thread_data_on_global =
                 n_block_data_on_global + c_thread_mtx_on_block.col;
 #endif
+            const index_t h_thread_id = get_thread_local_1d_id() / 8;
+            const index_t w_thread_id = get_thread_local_1d_id() % 8;
+
+            const index_t m_thread_data_on_global = m_block_data_on_global;
+            const index_t h_thread_data_on_global = h_block_data_on_global + h_thread_id;
+            const index_t w_thread_data_on_global = w_block_data_on_global + w_thread_id;
 
-            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+            // hack to control index calculation when iterating over c_k_n_h_w_global tensor
+            constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
 
-            constexpr auto tmp = make_unmerge_transform(make_tuple(
-                Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+            // constexpr auto tmp = make_unmerge_transform(make_tuple(
+            //     Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
 
             ThreadwiseDynamicTensorSliceTransfer_v1r3<
                 AccFloat,
                 Float,
-                decltype(c_m0_m1_n0_n1_thread_desc),
-                decltype(c_m0_m1_n0_n1_global_desc),
-                Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
-                CThreadTransferSrcDstAccessOrder,
-                CThreadTransferSrcDstVectorDim,
-                CThreadTransferDstScalarPerVector,
+                decltype(c_k_n_h_w_thread_desc),
+                decltype(c_k_n_h_w_global_desc),
+                Sequence<MPerThread, 1, 1, 1>,
+                Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
+                3,                    // CThreadTransferSrcDstVectorDim
+                1,                    // CThreadTransferDstScalarPerVector,
                 AddressSpace::Vgpr,
                 AddressSpace::Global,
                 CGlobalMemoryDataOperation,
                 1,
-                true>(c_m0_m1_n0_n1_global_desc,
-                      make_multi_index(m_thread_data_on_global / M1,
-                                       m_thread_data_on_global % M1,
-                                       n_thread_data_on_global / N1,
-                                       n_thread_data_on_global % N1))
-                .Run(c_m0_m1_n0_n1_thread_desc,
+                true>(c_k_n_h_w_global_desc,
+                      make_multi_index(m_thread_data_on_global,
+                                       0,
+                                       h_thread_data_on_global,
+                                       w_thread_data_on_global))
+                .Run(c_k_n_h_w_thread_desc,
                      make_tuple(I0, I0, I0, I0),
                      p_c_thread,
-                     c_m0_m1_n0_n1_global_desc,
+                     c_k_n_h_w_global_desc,
                      p_c_global,
-                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+                     c_k_n_h_w_global_tensor_iterator_hacks);
         }
 #endif
     }
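This hunk is the thread mapping the commit message refers to: thread id is decomposed as (h, w) = (tid / 8, tid % 8) inside the block's 8x8 output tile, and each thread writes an MPerThread-deep column of output channels at that pixel. A sketch of the index math (standalone illustration):

#include <cstdio>

int main()
{
    const int BlockSize = 64, MPerThread = 16;
    const int m_block_data_on_global = 0; // block origins; 0 for block 0
    const int h_block_data_on_global = 0;
    const int w_block_data_on_global = 0;

    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const int h = h_block_data_on_global + tid / 8;
        const int w = w_block_data_on_global + tid % 8;

        if(tid == 0 || tid == BlockSize - 1)
            std::printf("thread %2d -> out[k = %d..%d, n = 0, h = %d, w = %d]\n",
                        tid, m_block_data_on_global,
                        m_block_data_on_global + MPerThread - 1, h, w);
    }
    return 0;
}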
     // pass tensor descriptor by reference
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
+                        const BGlobalDesc& b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                        const CGlobalDesc& c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
@@ -443,9 +479,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
         Run(a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
+            b_cyx_n_h_w_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
+            c_k_n_h_w_global_desc,
             p_c_global,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
@@ -456,22 +492,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const BGlobalDesc* p_b_k_n_global_desc,
+                        const BGlobalDesc* p_b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
+                        const CGlobalDesc* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
         const auto a_k_m_global_desc = *p_a_k_m_global_desc;
-        const auto b_k_n_global_desc = *p_b_k_n_global_desc;
-        const auto c_m0_m1_n0_n1_global_desc = *p_c_m0_m1_n0_n1_global_desc;
+        const auto b_cyx_n_h_w_global_desc = *p_b_cyx_n_h_w_global_desc;
+        const auto c_k_n_h_w_global_desc   = *p_c_k_n_h_w_global_desc;
 
         Run(a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
+            b_cyx_n_h_w_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
+            c_k_n_h_w_global_desc,
             p_c_global,
             integral_constant<bool, HasMainKBlockLoop>{},
             integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -481,23 +517,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const void* p_a_k_m_global_desc,
                         const Float* __restrict__ p_a_global,
-                        const void* p_b_k_n_global_desc,
+                        const void* p_b_cyx_n_h_w_global_desc,
                         const Float* __restrict__ p_b_global,
-                        const void* p_c_m0_m1_n0_n1_global_desc,
+                        const void* p_c_k_n_h_w_global_desc,
                         Float* __restrict__ p_c_global,
                         integral_constant<bool, HasMainKBlockLoop>,
                         integral_constant<bool, HasDoubleTailKBlockLoop>) const
     {
         const auto a_k_m_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_k_m_global_desc);
-        const auto b_k_n_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_k_n_global_desc);
-        const auto c_m0_m1_n0_n1_global_desc =
-            *reinterpret_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc);
+        const auto b_cyx_n_h_w_global_desc =
+            *reinterpret_cast<const BGlobalDesc*>(p_b_cyx_n_h_w_global_desc);
+        const auto c_k_n_h_w_global_desc =
+            *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
 
         Run(a_k_m_global_desc,
             p_a_global,
-            b_k_n_global_desc,
+            b_cyx_n_h_w_global_desc,
             p_b_global,
-            c_m0_m1_n0_n1_global_desc,
+            c_k_n_h_w_global_desc,
             p_c_global,
             integral_constant<bool, HasMainKBlockLoop>{},
             integral_constant<bool, HasDoubleTailKBlockLoop>{});
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @ b3a012bc

@@ -136,7 +136,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         // loop over tensor and copy
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
-
             // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;

@@ -463,7 +462,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         // loop over tensor and copy
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
-
             // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;

@@ -500,7 +498,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             }();
 
             // copy data
-            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for ds_read");
+            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
 
             vector_type<SrcData, SrcScalarPerVector> src_vector;

@@ -798,7 +796,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         // loop over tensor and copy
         static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
-
             // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;

@@ -978,7 +975,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
         // loop over tensor and copy
         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
-
             // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
                 StaticallyIndexedArray<bool, nDim> forward_sweep;
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp @ b3a012bc

@@ -75,9 +75,9 @@ struct ThreadwiseGemm_km_kn_mn_v1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 
-        constexpr auto M = CDesc{}[I0];
-        constexpr auto N = CDesc{}[I1];
-        constexpr auto K = ADesc{}[I0];
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
 
         static_for<0, K, 1>{}([&](auto k) {
             static_for<0, M, 1>{}([&](auto m) {
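The change above only swaps operator[] for GetLength(); the threadwise GEMM remains a compile-time triple loop computing C += A^T * B on per-thread tiles. A stripped-down plain-loop equivalent (illustrative, not the library code):

#include <cstdio>

// Plain-loop equivalent of ThreadwiseGemm_km_kn_mn_v1::Run:
// c[m][n] += a[k][m] * b[k][n]
int main()
{
    const int K = 4, M = 2, N = 2;
    float a[K][M] = {}, b[K][N] = {}, c[M][N] = {};

    a[0][0] = 1.0f; // token data
    b[0][0] = 2.0f;

    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m][n] += a[k][m] * b[k][n];

    std::printf("c[0][0] = %f\n", c[0][0]); // 2.0
    return 0;
}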
composable_kernel/include/utility/config.amd.hpp.in @ b3a012bc

@@ -37,7 +37,7 @@
 #endif
 
 #ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
+#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
 #endif
 
 #ifndef CK_USE_AMD_V_FMAC_F32
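Flipping CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM from 1 to 0 makes the generic C++ path the default instead of the hand-written inline assembly. A sketch of how such a config macro typically gates the two paths (illustrative; the real dispatch lives inside the threadwise GEMM implementation):

#include <cstdio>

#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0 // value set by this commit
#endif

float fma_f32(float a, float b, float c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    // hand-written v_fmac_f32 inline assembly would live here
    return a * b + c;
#else
    return a * b + c; // generic C++ path, now the default
#endif
}

int main() { std::printf("%f\n", fma_f32(2.0f, 3.0f, 1.0f)); return 0; }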
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @ b3a012bc

@@ -67,7 +67,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif
 
-#if 1
     // cdata = 16, BlockSize = 64, 16x64x4
     constexpr index_t BlockSize = 64;
@@ -75,17 +74,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 4;
 
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
+    constexpr index_t GemmMPerThread = 16;
+    constexpr index_t GemmNPerThread = 1;
     constexpr index_t GemmKPerThread = 1;
 
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
+    constexpr index_t GemmMLevel0Cluster = 1;
+    constexpr index_t GemmNLevel0Cluster = 1;
+    constexpr index_t GemmMLevel1Cluster = 1;
+    constexpr index_t GemmNLevel1Cluster = 64;
 
     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<1, 1>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
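The new tuning numbers reshape the block into a single 1x64 thread cluster (1 * 1 * 1 * 64 threads) in which every thread owns a 16x1 macro-tile, so the 16x64 block tile is covered exactly once (MRepeat = NRepeat = 1). Checking the arithmetic:

#include <cassert>

int main()
{
    const int BlockSize = 64;
    const int GemmMPerBlock = 16, GemmNPerBlock = 64;
    const int GemmMPerThread = 16, GemmNPerThread = 1;
    const int MLevel0 = 1, NLevel0 = 1, MLevel1 = 1, NLevel1 = 64;

    // the cluster hierarchy must account for every thread in the block
    assert(MLevel0 * NLevel0 * MLevel1 * NLevel1 == BlockSize);

    // with a 16x1 tile per thread, the 16x64 block tile is covered once
    const int MRepeat = GemmMPerBlock / (GemmMPerThread * MLevel0 * MLevel1); // 1
    const int NRepeat = GemmNPerBlock / (GemmNPerThread * NLevel0 * NLevel1); // 1
    assert(MRepeat == 1 && NRepeat == 1);

    return 0;
}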
@@ -99,265 +95,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
 
     constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<1, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize = 64, 16x128x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 32x256x8
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<8, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x2
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 2;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<1, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<2, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<2, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // b thread copy 4x1
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // b thread copy 2x2
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<2, 2>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x16
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#endif
 
     constexpr auto conv_driver =
         DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
driver/src/conv_driver.cpp @ b3a012bc

@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
 #elif 1
     constexpr index_t N  = 1;
     constexpr index_t C  = 4;
-    constexpr index_t HI = 1080;
-    constexpr index_t WI = 1920;
+    constexpr index_t HI = 8;
+    constexpr index_t WI = 8;
     constexpr index_t K  = 16;
     constexpr index_t Y  = 3;
     constexpr index_t X  = 3;
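With the usual same-padding setup for a 3x3 filter at stride 1 (pad 1 on each side, which the pad driver variant is assumed to use here), the 8x8 input produces an 8x8 output, so GemmN = N * Ho * Wo = 64 — exactly one GemmNPerBlock tile and one 8x8 thread-mapped block. A quick check under those assumed pad/stride values:

#include <cstdio>

int main()
{
    const int N = 1, HI = 8, WI = 8, Y = 3, X = 3;
    const int pad = 1, stride = 1, dilation = 1; // assumed, not shown in the diff

    const int Ho = (HI + 2 * pad - dilation * (Y - 1) - 1) / stride + 1; // 8
    const int Wo = (WI + 2 * pad - dilation * (X - 1) - 1) / stride + 1; // 8

    std::printf("Ho=%d Wo=%d GemmN=%d\n", Ho, Wo, N * Ho * Wo); // 8 8 64
    return 0;
}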
@@ -657,7 +657,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-#if 0
+#if 1
         in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 0

@@ -776,9 +776,9 @@ int main(int argc, char* argv[])
         }
 
         check_error(out_nkhw_host, out_nkhw_device);
 
-#if 0
-        LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
-        LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
+#if 1
+        // LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+        // LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
         LogRange(std::cout << "out_nkhw_host  : ", out_nkhw_host.mData, ",") << std::endl;
         LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
 #endif