Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fa479ce4
Commit
fa479ce4
authored
Jan 07, 2021
by
Chao Liu
Browse files
modify gridwise dynamic gemm looping
parent
8e35a579
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
531 additions
and
118 deletions
+531
-118
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+366
-2
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+76
-62
driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+5
-5
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+58
-23
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+26
-26
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
fa479ce4
...
@@ -211,9 +211,369 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -211,9 +211,369 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
is_even_number
_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
const
bool
has_main
_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
if
(
is_even_number_k_block_loop
)
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
};
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
MultiIndex
<
2
>
conv_strides
,
const
MultiIndex
<
2
>
conv_dilations
,
const
MultiIndex
<
2
>
in_left_pads
,
const
MultiIndex
<
2
>
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
index_t
ConvStrideH
=
conv_strides
[
I0
];
const
index_t
ConvStrideW
=
conv_strides
[
I1
];
const
index_t
ConvDilationH
=
conv_dilations
[
I0
];
const
index_t
ConvDilationW
=
conv_dilations
[
I1
];
const
index_t
InLeftPadH
=
in_left_pads
[
I0
];
const
index_t
InLeftPadW
=
in_left_pads
[
I1
];
const
index_t
InRightPadH
=
in_right_pads
[
I0
];
const
index_t
InRightPadW
=
in_right_pads
[
I1
];
// weight tensor
#if 0
// TODO implement graph optimization of tensor descriptor transformation
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(DynamicPassThrough{K}, DynamicMerge<3>{make_multi_index(C, Y, X)}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
#else
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed
<
2
>
(
make_multi_index
(
K
,
C
*
Y
*
X
)),
make_tuple
(
DynamicPassThrough
{
K
},
DynamicPassThrough
{
C
*
Y
*
X
}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
#endif
// input tensor
// debug: don't do padding
const
auto
in_n_c_hip_wip_global_desc
=
in_n_c_hi_wi_global_desc
;
const
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I2
);
const
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLength
(
I3
);
const
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
DynamicPassThrough
{
N
},
DynamicPassThrough
{
C
},
DynamicEmbed
<
2
>
{
make_multi_index
(
Y
,
Ho
),
make_multi_index
(
ConvDilationH
,
ConvStrideH
)},
DynamicEmbed
<
2
>
{
make_multi_index
(
X
,
Wo
),
make_multi_index
(
ConvDilationW
,
ConvStrideW
)}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
const
auto
in_gemmk_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
DynamicMerge
<
3
>
{
make_multi_index
(
C
,
Y
,
X
)},
DynamicMerge
<
3
>
{
make_multi_index
(
N
,
Ho
,
Wo
)}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
#if 0
//TODO: implement graph optimization of tensor descriptor transformation
const auto out_gemmm_gemmn_global_desc =
transform_dynamic_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(DynamicPassThrough{K}, DynamicMerge<3>{make_mult_index(N, Ho, Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed
<
3
>
(
make_multi_index
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
DynamicPassThrough
{
K
},
DynamicMerge
<
2
>
{
make_multi_index
(
N
,
Ho
*
Wo
)}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#endif
const
index_t
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
index_t
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
index_t
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
index_t
GemmM0
=
GemmM
/
GemmM1
;
const
index_t
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
make_tuple
(
DynamicUnMerge
<
2
>
{
make_multi_index
(
GemmM0
,
GemmM1
)},
DynamicUnMerge
<
2
>
{
make_multi_index
(
GemmN0
,
GemmN1
)}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// GEMM
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v1
<
BlockSize
,
Float
,
AccFloat
,
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
true
,
// move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
run_gridwise_operation
<
gridwise_gemm
,
...
@@ -223,6 +583,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -223,6 +583,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
const
Float
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
...
@@ -236,6 +597,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -236,6 +597,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_in_global
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
integral_constant
<
bool
,
true
>
{});
}
}
else
else
...
@@ -248,6 +610,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -248,6 +610,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
const
Float
*
,
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
...
@@ -261,6 +624,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
...
@@ -261,6 +624,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_in_global
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
}
}
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
fa479ce4
...
@@ -73,7 +73,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -73,7 +73,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
Float
);
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
Float
);
}
}
template
<
typename
...
ADesc
,
typename
...
BDesc
,
typename
...
CDesc
,
bool
IsEvenNumberKBlockLoop
>
template
<
typename
...
ADesc
,
typename
...
BDesc
,
typename
...
CDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
DynamicTensorDescriptor
<
ADesc
...
>&
a_k_m_global_desc
,
__device__
void
Run
(
const
DynamicTensorDescriptor
<
ADesc
...
>&
a_k_m_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_a_global
,
const
DynamicTensorDescriptor
<
BDesc
...
>&
b_k_n_global_desc
,
const
DynamicTensorDescriptor
<
BDesc
...
>&
b_k_n_global_desc
,
...
@@ -81,7 +85,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -81,7 +85,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
const
DynamicTensorDescriptor
<
CDesc
...
>&
c_m0_m1_n0_n1_global_desc
,
const
DynamicTensorDescriptor
<
CDesc
...
>&
c_m0_m1_n0_n1_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_shared_block
,
Float
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
IsEvenNumberKBlockLoop
>
)
const
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -264,88 +269,91 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -264,88 +269,91 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
}
}
#endif
#endif
#if 1
if
constexpr
(
HasMainKBlockLoop
)
Float
*
p_a_block_even
=
p_a_block_double
;
{
Float
*
p_b_block_even
=
p_b_block_double
;
Float
*
p_a_block_even
=
p_a_block_double
;
Float
*
p_b_block_even
=
p_b_block_double
;
Float
*
p_a_block_odd
=
p_a_block_double
+
a_block_space_size
;
Float
*
p_a_block_odd
=
p_a_block_double
+
a_block_space_size
;
Float
*
p_b_block_odd
=
p_b_block_double
+
b_block_space_size
;
Float
*
p_b_block_odd
=
p_b_block_double
+
b_block_space_size
;
// LDS double buffer: main body
index_t
k_block_data_begin
=
0
;
for
(
index_t
k_block_data_begin
=
0
;
k_block_data_begin
<
K
-
2
*
KPerBlock
;
k_block_data_begin
+=
2
*
KPerBlock
)
{
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
__syncthreads
();
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
// LDS doubel buffer: load next data from device mem
__syncthreads
();
a_blockwise_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_blockwise_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on current data
// LDS doubel buffer: load next data from device mem
blockwise_gemm
.
Run
(
p_a_block_even
,
p_b_block_even
,
p_c_thread
);
a_blockwise_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_blockwise_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: store next data to LDS
// LDS double buffer: GEMM on current data
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_odd
);
blockwise_gemm
.
Run
(
p_a_block_even
,
p_b_block_even
,
p_c_thread
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_odd
);
// odd iteration
// LDS double buffer: store next data to LDS
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_
g
lo
bal
_desc
,
a_block_
slice_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_k_m_
b
lo
ck
_desc
,
p_
a_block_
odd
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_
g
lo
bal
_desc
,
b_block_
slice_copy_step
);
b_blockwise_copy
.
RunWrite
(
b_k_n_
b
lo
ck
_desc
,
p_
b_block_
odd
);
__syncthreads
();
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
// LDS doubel buffer: load next data from device mem
__syncthreads
();
a_blockwise_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_blockwise_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on current data
// LDS doubel buffer: load next data from device mem
blockwise_gemm
.
Run
(
p_a_block_odd
,
p_b_block_odd
,
p_c_thread
);
a_blockwise_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_blockwise_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_a_block_odd
,
p_b_block_odd
,
p_c_thread
);
// LDS double buffer: store next data to LDS
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_even
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_even
);
// LDS double buffer: store next data to LDS
k_block_data_begin
+=
2
*
KPerBlock
;
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_even
);
}
while
(
k_block_data_begin
<
K
-
2
*
KPerBlock
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_even
);
}
}
#endif
#if 1
#if 1
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
if
constexpr
(
IsEvenNumberKBlockLoop
)
// if has 2 iteration left
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
{
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k_m_global_desc
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k_n_global_desc
,
b_block_slice_copy_step
);
__syncthreads
();
__syncthreads
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
a_blockwise_copy
.
RunRead
(
a_k_m_global_desc
,
p_a_global
);
b_blockwise_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
b_blockwise_copy
.
RunRead
(
b_k_n_global_desc
,
p_b_global
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
// LDS double buffer: store last data to LDS
// LDS double buffer: store last data to LDS
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
+
a_block_space_size
);
a_blockwise_copy
.
RunWrite
(
a_k_m_block_desc
,
p_a_block_double
+
a_block_space_size
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
+
b_block_space_size
);
b_blockwise_copy
.
RunWrite
(
b_k_n_block_desc
,
p_b_block_double
+
b_block_space_size
);
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_a_block_double
+
a_block_space_size
,
blockwise_gemm
.
Run
(
p_a_block_double
+
a_block_space_size
,
p_b_block_double
+
b_block_space_size
,
p_b_block_double
+
b_block_space_size
,
p_c_thread
);
p_c_thread
);
}
}
else
// if has 1 iteration left
else
// if has 1 iteration left
{
{
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
}
}
}
#endif
#endif
...
@@ -398,14 +406,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -398,14 +406,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
}
}
}
}
template
<
typename
...
ADesc
,
typename
...
BDesc
,
typename
...
CDesc
,
bool
IsEvenNumberKBlockLoop
>
template
<
typename
...
ADesc
,
typename
...
BDesc
,
typename
...
CDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
DynamicTensorDescriptor
<
ADesc
...
>&
a_k_m_global_desc
,
__device__
void
Run
(
const
DynamicTensorDescriptor
<
ADesc
...
>&
a_k_m_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
Float
*
__restrict__
p_a_global
,
const
DynamicTensorDescriptor
<
BDesc
...
>&
b_k_n_global_desc
,
const
DynamicTensorDescriptor
<
BDesc
...
>&
b_k_n_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
Float
*
__restrict__
p_b_global
,
const
DynamicTensorDescriptor
<
CDesc
...
>&
c_m0_m1_n0_n1_global_desc
,
const
DynamicTensorDescriptor
<
CDesc
...
>&
c_m0_m1_n0_n1_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
IsEvenNumberKBlockLoop
>
)
const
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
Float
);
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
Float
);
...
@@ -418,7 +431,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -418,7 +431,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
c_m0_m1_n0_n1_global_desc
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
p_c_global
,
p_shared_block
,
p_shared_block
,
integral_constant
<
bool
,
IsEvenNumberKBlockLoop
>
{});
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
};
};
}
// namespace ck
}
// namespace ck
...
...
driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
fa479ce4
...
@@ -87,7 +87,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -87,7 +87,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif
0
#elif
1
// cdata = 64, BlockSize = 256, 128x128x4
// cdata = 64, BlockSize = 256, 128x128x4
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -99,10 +99,10 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -99,10 +99,10 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
8
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
fa479ce4
...
@@ -54,6 +54,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -54,6 +54,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
const
auto
in_right_pads
=
to_multi_index
(
InRightPads
{});
const
auto
in_right_pads
=
to_multi_index
(
InRightPads
{});
#if 1
#if 1
// cdata = 64, BlockSize = 256, 128x128x4
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
8
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
1
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -107,29 +137,34 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -107,29 +137,34 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
conv_driver
=
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
<
constexpr
auto
conv_driver
=
BlockSize
,
#if 0 // debug
TDevice
,
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
TDevice
,
#else
GemmMPerBlock
,
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
GemmNPerBlock
,
#endif
GemmKPerBlock
,
<
BlockSize
,
GemmMPerThread
,
TDevice
,
GemmNPerThread
,
TDevice
,
GemmKPerThread
,
GemmMPerBlock
,
GemmMLevel0Cluster
,
GemmNPerBlock
,
GemmNLevel0Cluster
,
GemmKPerBlock
,
GemmMLevel1Cluster
,
GemmMPerThread
,
GemmNLevel1Cluster
,
GemmNPerThread
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmKPerThread
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmMLevel0Cluster
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmNLevel0Cluster
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmMLevel1Cluster
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmNLevel1Cluster
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
{};
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
{
...
...
driver/src/conv_driver.cpp
View file @
fa479ce4
...
@@ -22,22 +22,22 @@ int main(int argc, char* argv[])
...
@@ -22,22 +22,22 @@ int main(int argc, char* argv[])
{
{
using
namespace
ck
;
using
namespace
ck
;
#if
0
#if
1
//
1x1, 8x8
//
3x3, 35x35, stride 2
constexpr index_t N =
2
;
constexpr
index_t
N
=
128
;
constexpr index_t C = 2
4
;
constexpr
index_t
C
=
19
2
;
constexpr index_t HI =
8
;
constexpr
index_t
HI
=
35
;
constexpr index_t WI =
8
;
constexpr
index_t
WI
=
35
;
constexpr index_t K =
128
;
constexpr
index_t
K
=
384
;
constexpr index_t Y =
1
;
constexpr
index_t
Y
=
3
;
constexpr index_t X =
1
;
constexpr
index_t
X
=
3
;
using ConvStrides = Sequence<
1
,
1
>;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 3x3, 71x71
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
C
=
192
;
...
@@ -127,7 +127,7 @@ int main(int argc, char* argv[])
...
@@ -127,7 +127,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
1
#elif
0
// 1x7, 17x17
// 1x7, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
@@ -217,7 +217,7 @@ int main(int argc, char* argv[])
...
@@ -217,7 +217,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 3x3, 35x35, stride 2
// 3x3, 35x35, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
288
;
constexpr
index_t
C
=
288
;
...
@@ -352,7 +352,7 @@ int main(int argc, char* argv[])
...
@@ -352,7 +352,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3, 28x28
// 3x3, 28x28
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
@@ -367,7 +367,7 @@ int main(int argc, char* argv[])
...
@@ -367,7 +367,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
0
#elif
1
// 3x3, 14x14
// 3x3, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -567,17 +567,17 @@ int main(int argc, char* argv[])
...
@@ -567,17 +567,17 @@ int main(int argc, char* argv[])
#if 0
#if 0
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
in_nchw,
wei_kcyx_desc,
wei_kcyx_desc,
wei_kcyx,
wei_kcyx,
out_nkhw_desc,
out_nkhw_desc,
out_nkhw_device,
out_nkhw_device,
ConvStrides{},
ConvStrides{},
ConvDilations{},
ConvDilations{},
LeftPads{},
LeftPads{},
RightPads{},
RightPads{},
nrepeat);
nrepeat);
#elif
1
#elif
0
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment