gaoqiong / composable_kernel_ROCM · Commits

Commit 3406a114 (unverified)
Authored Jan 27, 2020 by Chao Liu; committed by GitHub on Jan 27, 2020

Update for recent MIOpen integration (#11)

* update for MIOpen integration

Parent: c5da0377

Showing 20 changed files with 271 additions and 199 deletions (+271 -199)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp  +1 -5
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp  +1 -1
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  +8 -8
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp  +8 -8
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +56 -15
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  +3 -4
composable_kernel/include/tensor_description/multi_index_transform.hpp  +10 -4
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  +38 -37
composable_kernel/include/utility/amd_inline_asm.hpp  +0 -7
composable_kernel/include/utility/config.amd.hpp.in  +4 -3
composable_kernel/include/utility/in_memory_operation.amd.hpp.in  +1 -1
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in  +1 -1
composable_kernel/include/utility/math.hpp  +9 -9
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  +4 -4
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp  +4 -4
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +59 -62
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +40 -2
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  +3 -3
driver/src/conv_bwd_data_driver.cpp  +15 -15
driver/src/conv_driver.cpp  +6 -6
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp

@@ -49,7 +49,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
                         const Float* __restrict__ p_wei_global,
                         const Float* __restrict__ p_out_global) const
     {
-        constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

@@ -85,11 +84,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
                       "be violated");

         // output tensor
-        constexpr auto out_n_k_howo_global_desc =
-            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
-
         constexpr auto out_k_b_global_desc =
-            transform_tensor_descriptor(out_n_k_howo_global_desc,
+            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
                                         make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
                                         make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp

@@ -353,7 +353,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
         }

         {
-#if 1
+#if 1 // debug
             // input: register to global memory, atomic add
             constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
                                               ? InMemoryDataOperation::none
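The guarded block above selects the write operation for the backward-data output: when the filter window fits inside the stride (Y <= ConvStrideH && X <= ConvStrideW) the kernel uses a plain store, otherwise it falls back to atomic adds. Below is a minimal host-side sketch of that selection; choose_in_memory_op is a hypothetical helper introduced only for this illustration, and the "no overlap" reasoning in the comment is one reading of the condition, not text from the commit.

    // Sketch only: when the filter fits inside the stride, the input regions
    // written for different output pixels do not overlap, so a plain store
    // ("none") suffices; otherwise contributions overlap and must be
    // accumulated with atomic adds.
    enum class InMemoryDataOperation { none, atomic_add };

    template <int Y, int X, int ConvStrideH, int ConvStrideW>
    constexpr InMemoryDataOperation choose_in_memory_op()
    {
        return (Y <= ConvStrideH && X <= ConvStrideW) ? InMemoryDataOperation::none
                                                      : InMemoryDataOperation::atomic_add;
    }

    static_assert(choose_in_memory_op<1, 1, 2, 2>() == InMemoryDataOperation::none,
                  "1x1 filter, stride 2: plain store");
    static_assert(choose_in_memory_op<3, 3, 1, 1>() == InMemoryDataOperation::atomic_add,
                  "3x3 filter, stride 1: accumulate");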
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp

@@ -81,11 +81,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                       "be violated");
 #endif

-        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

         constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
         constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

@@ -115,10 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
-                      Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
+                      Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
-                      Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
+                      Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

@@ -135,10 +135,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                 PassThrough<K>{},
                 Embed<Ho,
                       Sequence<Ydot, Htilda>,
-                      Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
+                      Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
                 Embed<Wo,
                       Sequence<Xdot, Wtilda>,
-                      Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
+                      Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp

@@ -110,11 +110,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                       "be violated");
 #endif

-        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

         constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
         constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

@@ -146,11 +146,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
-                      Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
                       wei_skip_all_out_of_bound_check>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
-                      Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
                       wei_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

@@ -168,11 +168,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
                 PassThrough<K>{},
                 Embed<Ho,
                       Sequence<Ydot, Htilda>,
-                      Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
                       out_skip_all_out_of_bound_check>{},
                 Embed<Wo,
                       Sequence<Xdot, Wtilda>,
-                      Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
                       out_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp

@@ -22,8 +22,6 @@ template <index_t GridSize,
           typename ConvDilations,
           typename InLeftPads,
           typename InRightPads,
-          index_t Iter_ytilda,
-          index_t Iter_xtilda,
           index_t GemmMPerBlock,
           index_t GemmNPerBlock,
           index_t GemmKPerBlock,

@@ -47,9 +45,27 @@ template <index_t GridSize,
           index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
 struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
 {
-    __device__ void Run(Float* __restrict__ p_in_global,
-                        const Float* __restrict__ p_wei_global,
-                        const Float* __restrict__ p_out_global) const
+    __host__ __device__ static constexpr index_t GetNumberOfGemm()
+    {
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+
+        return Ytilda * Xtilda;
+    }
+
+    template <index_t iYTilda, index_t iXTilda>
+    __device__ static void RunImpl(Float* __restrict__ p_in_global,
+                                   const Float* __restrict__ p_wei_global,
+                                   const Float* __restrict__ p_out_global)
     {
         constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
         constexpr auto wei_k_c_y_x_global_desc  = WeiGlobalDesc{};

@@ -83,11 +99,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                       "be violated");
 #endif

-        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

         constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
         constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

@@ -119,11 +135,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Embed<Y,
                       Sequence<Ydot, Ytilda>,
-                      Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
                       wei_skip_all_out_of_bound_check>{},
                 Embed<X,
                       Sequence<Xdot, Xtilda>,
-                      Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
                       wei_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

@@ -141,11 +157,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                 PassThrough<K>{},
                 Embed<Ho,
                       Sequence<Ydot, Htilda>,
-                      Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
+                      Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
                       out_skip_all_out_of_bound_check>{},
                 Embed<Wo,
                       Sequence<Xdot, Wtilda>,
-                      Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
+                      Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
                       out_skip_all_out_of_bound_check>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

@@ -215,8 +231,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
                 Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));

         // GEMM
-        constexpr index_t ytilda = Iter_ytilda;
-        constexpr index_t xtilda = Iter_xtilda;
+        constexpr index_t ytilda = iYTilda;
+        constexpr index_t xtilda = iXTilda;

         constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
         constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;

@@ -327,6 +343,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
         gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
     }
+
+    template <index_t GemmId>
+    __device__ static void Run(Float* __restrict__ p_in_global,
+                               const Float* __restrict__ p_wei_global,
+                               const Float* __restrict__ p_out_global)
+    {
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+
+        constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+        constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
+
+        constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
+
+        constexpr index_t iYTilda = GemmId / Xtilda;
+        constexpr index_t iXTilda = GemmId % Xtilda;
+
+        static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda");
+
+        RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
+    }
 };
 } // namespace ck
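The new GetNumberOfGemm() above splits backward data into Ytilda * Xtilda independent GEMMs, where Ytilda = ConvStrideH / gcd(ConvStrideH, ConvDilationH) and Xtilda is defined analogously for the width. A small worked example of that arithmetic follows; the 2x2 stride and 1x1 dilation are assumed values chosen for illustration, not taken from the commit.

    #include <numeric> // std::gcd, constexpr since C++17

    // Assumed values for illustration only.
    constexpr int ConvStrideH   = 2, ConvStrideW   = 2;
    constexpr int ConvDilationH = 1, ConvDilationW = 1;

    constexpr int Ytilda = ConvStrideH / std::gcd(ConvStrideH, ConvDilationH); // 2 / 1 = 2
    constexpr int Xtilda = ConvStrideW / std::gcd(ConvStrideW, ConvDilationW); // 2 / 1 = 2

    // RunImpl<iYTilda, iXTilda> is instantiated once per pair, so the driver
    // launches GetNumberOfGemm() = Ytilda * Xtilda kernels.
    static_assert(Ytilda * Xtilda == 4, "stride 2x2, dilation 1x1 -> 4 GEMMs");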
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -49,7 +49,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
                         const Float* const __restrict__ p_wei_global,
                         Float* const __restrict__ p_out_global) const
     {
-        constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

@@ -117,9 +116,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
         // output tensor
         constexpr auto out_k_b_global_desc =
-            transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
-                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
-                                        make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
+            transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
+                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
+                                        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));

         // GEMM
composable_kernel/include/tensor_description/multi_index_transform.hpp

@@ -47,6 +47,9 @@ struct PassThrough
     }
 };

+// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
+// necessary
+// However, the check will be skipped if SkipIsValidCheck is set to true by user
 // LowerLengths: Sequence<...>
 template <typename LowerLengths,
           typename LeftPads,

@@ -92,12 +95,12 @@ struct Pad
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
-#if 1 // debug
+        // skip valid check if user request it
         if(SkipIsValidCheck)
         {
             return true;
         }
-#endif

         bool flag = true;

         for(index_t i = 0; i < nDim; ++i)

@@ -384,6 +387,9 @@ struct UnMerge
     }
 };

+// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
+// necessary
+// However, the check will be skipped if SkipIsValidCheck is set to true by user
 // UpperLengths: Sequence<...>
 // Coefficients: Sequence<...>
 // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]

@@ -442,12 +448,12 @@ struct Embed
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
-#if 1 // debug
+        // skip valid check if user request it
         if(SkipIsValidCheck)
         {
             return true;
         }
-#endif

         bool flag = true;

         index_t ncorner = 1;
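The comment added above says each transform normally judges on its own whether the upper-to-lower index mapping needs an out-of-bound check, and that SkipIsValidCheck lets the user force the check to be skipped. A simplified, self-contained sketch of that early-return shape follows; ToyTransform is a made-up type for illustration, not the real Pad or Embed transforms, whose automatic judgement inspects pads and corner points.

    // Toy illustration of the SkipIsValidCheck early-return pattern.
    template <bool SkipIsValidCheck>
    struct ToyTransform
    {
        static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
        {
            // skip valid check if user requests it
            if(SkipIsValidCheck)
            {
                return true;
            }

            // otherwise judge automatically; the real transforms walk their
            // pads/corners here, this toy simply reports "not guaranteed valid"
            return false;
        }
    };

    static_assert(ToyTransform<true>::IsValidUpperIndexAlwaysMappedToValidLowerIndex(),
                  "forced skip");
    static_assert(!ToyTransform<false>::IsValidUpperIndexAlwaysMappedToValidLowerIndex(),
                  "automatic judgement");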
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp

@@ -112,11 +112,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
             // has the valid/invalid mapping situation
             if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
             {
-                move_data<SrcData,
-                          SrcDataPerRead,
-                          SrcAddressSpace,
-                          AddressSpace::vgpr,
-                          InMemoryDataOperation::none>(
-                    p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
+                transfer_data<SrcData,
+                              SrcDataPerRead,
+                              SrcAddressSpace,
+                              AddressSpace::vgpr,
+                              InMemoryDataOperation::none>(
+                    p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
             }
         }

@@ -144,11 +144,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
             // has the valid/invalid mapping situation
             if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
             {
-                move_data<DstData,
-                          DstDataPerWrite,
-                          AddressSpace::vgpr,
-                          DstAddressSpace,
-                          DstInMemOp>(
-                    p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
+                transfer_data<DstData,
+                              DstDataPerWrite,
+                              AddressSpace::vgpr,
+                              DstAddressSpace,
+                              DstInMemOp>(
+                    p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
             }
         }

@@ -262,15 +262,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<SrcData,
-                              SrcDataPerRead,
-                              SrcAddressSpace,
-                              AddressSpace::vgpr,
-                              InMemoryDataOperation::none>(p_src,
-                                                           src_nonlinear_coord.GetOffset() +
-                                                               src_linear_offset,
-                                                           p_src_long_vector,
-                                                           buffer_offset);
+                    transfer_data<SrcData,
+                                  SrcDataPerRead,
+                                  SrcAddressSpace,
+                                  AddressSpace::vgpr,
+                                  InMemoryDataOperation::none>(p_src,
+                                                               src_nonlinear_coord.GetOffset() +
+                                                                   src_linear_offset,
+                                                               p_src_long_vector,
+                                                               buffer_offset);
                 }
             }

@@ -301,11 +301,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<DstData,
-                              DstDataPerWrite,
-                              AddressSpace::vgpr,
-                              DstAddressSpace,
-                              DstInMemOp>(
-                        p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
+                    transfer_data<DstData,
+                                  DstDataPerWrite,
+                                  AddressSpace::vgpr,
+                                  DstAddressSpace,
+                                  DstInMemOp>(
+                        p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
                 }
             }

@@ -401,11 +401,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<SrcData,
-                              SrcDataPerRead,
-                              SrcAddressSpace,
-                              AddressSpace::vgpr,
-                              InMemoryDataOperation::none>(
-                        p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
+                    transfer_data<SrcData,
+                                  SrcDataPerRead,
+                                  SrcAddressSpace,
+                                  AddressSpace::vgpr,
+                                  InMemoryDataOperation::none>(
+                        p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
                 }
             }

@@ -446,14 +446,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // has the valid/invalid mapping situation
                 if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
                 {
-                    move_data<DstData,
-                              DstDataPerWrite,
-                              AddressSpace::vgpr,
-                              DstAddressSpace,
-                              DstInMemOp>(p_dst_long_vector,
-                                          buffer_offset,
-                                          p_dst,
-                                          dst_nonlinear_coord.GetOffset() + dst_linear_offset);
+                    transfer_data<DstData,
+                                  DstDataPerWrite,
+                                  AddressSpace::vgpr,
+                                  DstAddressSpace,
+                                  DstInMemOp>(p_dst_long_vector,
+                                              buffer_offset,
+                                              p_dst,
+                                              dst_nonlinear_coord.GetOffset() + dst_linear_offset);
                 }
             }
         });
composable_kernel/include/utility/amd_inline_asm.hpp

@@ -8,19 +8,12 @@ namespace ck {
 // outer-product: c[i,j] += inner_product(a[i], b[j])
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
 {
-    // disable inline asm due to the compiler issue: SWDEV-202749
-    ///\to-do: enable the inline asm after the compiler fix
-#if CK_WORKAROUND_SWDEV_202749
-    c0 += a * b0;
-    c1 += a * b1;
-#else
     asm volatile("\n \
                  v_mac_f32 %0, %2, %3 \n \
                  v_mac_f32 %1, %2, %4 \n \
                  "
                  : "=v"(c0), "=v"(c1)
                  : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
-#endif
 }

 // outer-product: c[i,j] += inner_product(a[i], b[j])
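For reference, the scalar fallback removed above (the CK_WORKAROUND_SWDEV_202749 branch) is the plain-C++ equivalent of the two v_mac_f32 instructions that are now always emitted; the function name below is only for this sketch.

    // Equivalent scalar form of the 1x2 outer-product update; this is exactly
    // what the removed workaround branch computed instead of the inline asm.
    inline void outer_product_1x2_ref(float a, float b0, float b1, float& c0, float& c1)
    {
        c0 += a * b0; // v_mac_f32 %0, %2, %3
        c1 += a * b1; // v_mac_f32 %1, %2, %4
    }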
composable_kernel/include/utility/config.amd.hpp.in

@@ -43,6 +43,10 @@
 #define CK_USE_AMD_XDLOPS_INLINE_ASM 0
 #endif

+#ifndef CK_USE_AMD_XDLOPS_EMULATE
+#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
+#endif
+
 // experimental implementation
 #define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
 #define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0

@@ -51,9 +55,6 @@
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0

-// workaround
-#define CK_WORKAROUND_SWDEV_202749 1
-
 namespace ck {

 enum AddressSpace
composable_kernel/include/utility/in_memory_operation.amd.hpp.in

@@ -70,7 +70,7 @@ template <typename T,
           AddressSpace SrcAddressSpace,
           AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp>
-__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     static_assert(DstInMemOp == InMemoryDataOperation::none ||
                       DstInMemOp == InMemoryDataOperation::atomic_add,
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in

@@ -38,7 +38,7 @@ template <typename T,
           AddressSpace SrcAddressSpace,
           AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp>
-__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
+__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
 {
     static_assert(DstInMemOp == InMemoryDataOperation::none ||
                       DstInMemOp == InMemoryDataOperation::atomic_add,
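Both back-ends now expose the copy primitive as transfer_data, and the static_assert in the hunks above restricts DstInMemOp to none or atomic_add. The following is a host-side, scalar sketch of that contract only; the real function is a __device__ template that also takes per-access width and address-space parameters, so this is an assumed simplification, not the library implementation.

    // Sketch of the behaviour implied by the static_assert: DstInMemOp selects
    // between a plain store ("none") and an accumulate ("atomic_add", which is
    // done atomically on the GPU). Illustration only.
    enum class InMemoryDataOperation { none, atomic_add };

    template <typename T, InMemoryDataOperation DstInMemOp>
    void transfer_data_sketch(const T* p_src, int src_offset, T* p_dst, int dst_offset)
    {
        static_assert(DstInMemOp == InMemoryDataOperation::none ||
                          DstInMemOp == InMemoryDataOperation::atomic_add,
                      "wrong! unsupported operation");

        if(DstInMemOp == InMemoryDataOperation::none)
        {
            p_dst[dst_offset] = p_src[src_offset];
        }
        else
        {
            p_dst[dst_offset] += p_src[src_offset];
        }
    }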
composable_kernel/include/utility/math.hpp

@@ -103,9 +103,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
     return x < y ? x : y;
 }

-// highest common factor
+// greatest common divisor, aka highest common factor
 template <typename T>
-__host__ __device__ constexpr T hcf(T x, T y)
+__host__ __device__ constexpr T gcd(T x, T y)
 {
     if(x == 0)
     {

@@ -124,30 +124,30 @@ __host__ __device__ constexpr T hcf(T x, T y)
     if(x > y)
     {
-        return hcf(x - y, y);
+        return gcd(x - y, y);
     }

-    return hcf(x, y - x);
+    return gcd(x, y - x);
 }

 template <index_t X, index_t Y>
-__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
+__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
 {
-    constexpr auto result = hcf(X, Y);
+    constexpr auto result = gcd(X, Y);

     return Number<result>{};
 }

 template <typename X, typename... Ys>
-__host__ __device__ constexpr auto hcf(X x, Ys... ys)
+__host__ __device__ constexpr auto gcd(X x, Ys... ys)
 {
-    return hcf(x, ys...);
+    return gcd(x, ys...);
 }

 // least common multiple
 template <typename T>
 __host__ __device__ constexpr T lcm(T x, T y)
 {
-    return (x * y) / hcf(x, y);
+    return (x * y) / gcd(x, y);
 }

 template <typename X, typename Y, typename... Zs>
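The renamed gcd in math.hpp is the subtraction form of Euclid's algorithm and stays constexpr, so the tilde factors used in the kernels above are computed at compile time; lcm is built on top of it. Below is a standalone mirror with a compile-time check, using illustrative values; it is a sketch, not the library source, which lives in ck::math and is marked __host__ __device__.

    // Standalone mirror of the subtraction-based gcd/lcm above.
    // gcd(6, 4) = 2, hence lcm(6, 4) = (6 * 4) / 2 = 12.
    constexpr int gcd(int x, int y)
    {
        if(x == 0) { return y; }
        if(y == 0) { return x; }
        if(x == y) { return x; }
        if(x > y) { return gcd(x - y, y); }
        return gcd(x, y - x);
    }

    constexpr int lcm(int x, int y) { return (x * y) / gcd(x, y); }

    static_assert(gcd(6, 4) == 2, "e.g. gcd(ConvStride, ConvDilation)");
    static_assert(lcm(6, 4) == 12, "");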
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp

@@ -152,11 +152,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #endif

-    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp

@@ -91,11 +91,11 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

-    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp

@@ -2,13 +2,18 @@
 #include <unistd.h>
 #include "device.hpp"
 #include "tensor.hpp"
-#include "gridwise_operation_wrapper.hpp"
 #include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"

 namespace launcher {

 using namespace ck;

+template <typename GridwiseOp, index_t GemmId, typename... Xs>
+__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs)
+{
+    GridwiseOp::template Run<GemmId>(xs...);
+}
+
 template <typename T,
           typename InDesc,
           typename WeiDesc,

@@ -119,11 +124,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

-    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
-    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+    constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
+    constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+    constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);

@@ -154,69 +159,61 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
     for(index_t i = 0; i < nrepeat; ++i)
     {
-        KernelTimer timer;
+        using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
+            GridSize,
+            BlockSize,
+            T,
+            T,
+            decltype(in_nchw_desc),
+            decltype(wei_kcyx_desc),
+            decltype(out_nkhw_desc),
+            ConvStrides,
+            ConvDilations,
+            InLeftPads,
+            InRightPads,
+            GemmMPerBlock,
+            GemmNPerBlock,
+            GemmKPerBlock,
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            GemmThreadGemmDataPerReadM,
+            GemmThreadGemmDataPerReadN,
+            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+            GemmABlockCopySrcDataPerRead_GemmM,
+            GemmABlockCopyDstDataPerWrite_GemmM,
+            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+            GemmBBlockCopySrcDataPerRead_GemmN,
+            GemmBBlockCopyDstDataPerWrite_GemmN,
+            GemmCThreadCopyDstDataPerWrite_GemmN1>;
+
+        KernelTimer timer;
+
         timer.Start();

-        static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
-            static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
-                constexpr index_t ytilda = decltype(ytilda_){};
-                constexpr index_t xtilda = decltype(xtilda_){};
-
-                constexpr auto gridwise_conv =
-                    GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
-                        GridSize,
-                        BlockSize,
-                        T,
-                        T,
-                        decltype(in_nchw_desc),
-                        decltype(wei_kcyx_desc),
-                        decltype(out_nkhw_desc),
-                        ConvStrides,
-                        ConvDilations,
-                        InLeftPads,
-                        InRightPads,
-                        ytilda,
-                        xtilda,
-                        GemmMPerBlock,
-                        GemmNPerBlock,
-                        GemmKPerBlock,
-                        GemmMPerThreadSubC,
-                        GemmNPerThreadSubC,
-                        GemmMLevel0Cluster,
-                        GemmNLevel0Cluster,
-                        GemmMLevel1Cluster,
-                        GemmNLevel1Cluster,
-                        GemmKPerThreadLoop,
-                        GemmThreadGemmDataPerReadM,
-                        GemmThreadGemmDataPerReadN,
-                        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
-                        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-                        GemmABlockCopySrcDataPerRead_GemmM,
-                        GemmABlockCopyDstDataPerWrite_GemmM,
-                        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
-                        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
-                        GemmBBlockCopySrcDataPerRead_GemmN,
-                        GemmBBlockCopyDstDataPerWrite_GemmN,
-                        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
-
-                launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
-                                                              T* const __restrict__,
-                                                              const T* const __restrict__,
-                                                              const T* const __restrict__>,
-                                       dim3(GridSize),
-                                       dim3(BlockSize),
-                                       0,
-                                       0,
-                                       gridwise_conv,
-                                       static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                                       static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-                                       static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
-            });
-        });
+        static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
+            constexpr index_t gemm_id = decltype(gemm_id_){};
+
+            launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
+                                                                      gemm_id,
+                                                                      T* const __restrict__,
+                                                                      const T* const __restrict__,
+                                                                      const T* const __restrict__>,
+                          dim3(GridSize),
+                          dim3(BlockSize),
+                          0,
+                          0,
+                          static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                          static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        });

         timer.End();

         float time = timer.GetElapsedTime();

         printf("Elapsed time : %f ms, %f TFlop/s\n",
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp

@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-#if 1
+#if 0
     // BlockSize = 256, EperBlock = 8, each thread hold 64 data
     constexpr index_t BlockSize = 256;

@@ -127,7 +127,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]

     constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
+#elif 1
+    // BlockSize = 256, EPerBlock = 16, each thread hold 64 data
+    // for 1x1
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t BPerBlock = 16;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t EPerBlock = 16;
+
+    constexpr index_t GemmNRepeat = 2;
+
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmDataPerReadA   = 4;
+    constexpr index_t GemmDataPerReadB   = 4;
+
+    using InBlockCopySubLengths_E_N1_B_N2      = Sequence<4, 1, 1, 2>;
+    using InBlockCopyClusterLengths_E_N1_B_N2  = Sequence<4, 2, 16, 2>;
+    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
+    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
+    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
+
+    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
+    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
+
+    using WeiBlockCopySubLengths_E_K            = Sequence<4, 2>;
+    using WeiBlockCopyClusterLengths_E_K        = Sequence<4, 64>;
+    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
+    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
+
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
 #elif 1
     // BlockSize = 64, each thread hold 64 data
     constexpr index_t BlockSize = 64;
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN   = 1;
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
+#elif 0
     // BlockSize = 256, GemmKPerBlock = 16
     constexpr index_t BlockSize = 256;

@@ -117,7 +117,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 0
     // BlockSize = 256, GemmKPerBlock = 8
-    // 1x1 filter, 8x8 image
+    // for 1x1 filter, vector-read-b = 4
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 128;

@@ -149,7 +149,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 1
     // BlockSize = 256, GemmKPerBlock = 16
-    // 1x1 filter, 8x8 image
+    // for 1x1 filter, vector-read-b = 4
     constexpr index_t BlockSize     = 256;
     constexpr index_t GemmMPerBlock = 128;
driver/src/conv_bwd_data_driver.cpp

@@ -161,10 +161,10 @@ int main(int argc, char* argv[])
 #elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N  = 128;
-    constexpr index_t C  = 1024;
+    constexpr index_t C  = 128;
     constexpr index_t HI = 17;
     constexpr index_t WI = 17;
-    constexpr index_t K  = 1024;
+    constexpr index_t K  = 128;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 7;

@@ -246,28 +246,28 @@ int main(int argc, char* argv[])
 #endif
     }

-#if 0
+#if 1
     device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
 #elif 0
     device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
-#elif 1
+#elif 0
     device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
 #elif 0
     device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
 #elif 1
     device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
 #endif
     (in_nchw_desc,
      in_nchw_device,
      wei_kcyx_desc,
      wei_kcyx,
      out_nkhw_desc,
      out_nkhw,
      ConvStrides{},
      ConvDilations{},
      LeftPads{},
      RightPads{},
      nrepeat);

     if(do_verification)
     {
driver/src/conv_driver.cpp

@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 0
+#if 1
     // 1x1
-    constexpr index_t N  = 256;
-    constexpr index_t C  = 1024;
-    constexpr index_t HI = 8;
-    constexpr index_t WI = 8;
-    constexpr index_t K  = 1024;
+    constexpr index_t N  = 64;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K  = 256;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 1;