Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0be1cf14
"src/include/utility.hpp" did not exist on "e43d7bc63c2df138c376412fa5b4aaebc26ca131"
Commit
0be1cf14
authored
Jul 17, 2022
by
Chao Liu
Browse files
update conv bwd weight
parent
b054669b
Changes
30
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
177 additions
and
519 deletions
+177
-519
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp
...gpu/device/convolution_backward_weight_specialization.hpp
+13
-0
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
...eration/gpu/device/convolution_forward_specialization.hpp
+1
-3
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+9
-2
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+8
-1
include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp
...e/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp
+8
-9
include/ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp
...ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp
+8
-8
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
+1
-2
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...ion/gpu/device/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+42
-25
include/ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp
...ice/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp
+67
-426
include/ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp
...peration/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp
+20
-43
No files found.
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp
View file @
0be1cf14
...
...
@@ -15,6 +15,19 @@ enum struct ConvolutionBackwardWeightSpecialization
OddC
,
};
inline
std
::
string
getConvBackwardWeightSpecializationString
(
const
ConvolutionBackwardWeightSpecialization
&
s
)
{
switch
(
s
)
{
case
ConvolutionBackwardWeightSpecialization
::
Default
:
return
"Default"
;
case
ConvolutionBackwardWeightSpecialization
::
Filter1x1Stride1Pad0
:
return
"Filter1x1Stride1Pad0"
;
case
ConvolutionBackwardWeightSpecialization
::
Filter1x1Pad0
:
return
"Filter1x1Pad0"
;
case
ConvolutionBackwardWeightSpecialization
::
OddC
:
return
"OddC"
;
default:
return
"Unrecognized specialization!"
;
}
}
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
View file @
0be1cf14
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
#pragma once
#include <string>
...
...
@@ -33,4 +32,3 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
0be1cf14
...
...
@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_b
ackwar
d_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_b
w
d_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
...
...
@@ -57,7 +57,14 @@ template <typename InDataType,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvBwdWeight
<
InElementwiseOperation
,
:
public
DeviceConvBwdWeight
<
2
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
0be1cf14
...
...
@@ -55,7 +55,14 @@ template <typename InDataType,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvBwdData
<
InElementwiseOperation
,
:
public
DeviceConvBwdData
<
2
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
...
...
include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp
View file @
0be1cf14
...
...
@@ -4,16 +4,21 @@
#pragma once
#include <vector>
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InElementwiseOperation
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvBwdData
:
public
BaseOperator
...
...
@@ -39,12 +44,6 @@ struct DeviceConvBwdData : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvBwdDataPtr
=
std
::
unique_ptr
<
DeviceConvBwdData
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_conv_b
ackwar
d_weight.hpp
→
include/ck/tensor_operation/gpu/device/device_conv_b
w
d_weight.hpp
View file @
0be1cf14
...
...
@@ -4,7 +4,6 @@
#pragma once
#include <vector>
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
...
...
@@ -12,7 +11,14 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InElementwiseOperation
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvBwdWeight
:
public
BaseOperator
...
...
@@ -39,12 +45,6 @@ struct DeviceConvBwdWeight : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvBwdWeightPtr
=
std
::
unique_ptr
<
DeviceConvBwdWeight
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp
View file @
0be1cf14
...
...
@@ -3,7 +3,6 @@
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
...
...
@@ -12,7 +11,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
ck
::
index_t
N
um
DimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
...
...
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_
xdl_ndh
wc_k
zy
xc_n
dh
wk.hpp
→
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_
n
wc_kxc_nwk
_xdl
.hpp
View file @
0be1cf14
...
...
@@ -21,7 +21,8 @@ namespace tensor_operation {
namespace
device
{
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
typename
InDataType
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
...
...
@@ -29,7 +30,6 @@ template <typename InDataType,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
...
...
@@ -55,12 +55,29 @@ template <typename InDataType,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
:
public
DeviceConvBwdData
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
struct
DeviceConvNdBwdDataNwcKxcNwk_Xdl
:
public
DeviceConvBwdData
<
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>>
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceConv
n
dBwdData
Xdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
;
using
DeviceOp
=
DeviceConv
N
dBwdData
NwcKxcNwk_Xdl
;
using
ADataType
=
OutDataType
;
using
BDataType
=
WeiDataType
;
...
...
@@ -950,7 +967,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
{
0
,
0
,
0
});
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
N
um
DimSpatial
>
());
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
...
...
@@ -1037,7 +1054,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
CreateABCDesc
<
N
um
DimSpatial
>
();
CreateABCDesc
<
NDimSpatial
>
();
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
...
...
@@ -1060,7 +1077,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
N
um
DimSpatial
>
(
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
...
...
@@ -1118,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
N
um
DimSpatial
>
(
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
...
...
@@ -1186,18 +1203,18 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NumDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
...
...
@@ -1398,7 +1415,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 pad = 0 conv
for
(
int
i
=
0
;
i
<
N
um
DimSpatial
;
i
++
)
for
(
int
i
=
0
;
i
<
NDimSpatial
;
i
++
)
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
...
...
@@ -1528,7 +1545,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceConv
n
dBwdData
Xdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
"
str
<<
"DeviceConv
N
dBwdData
NwcKxcNwk_Xdl
"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/device_convnd_b
ackwar
d_weight_xdl_c
_
shuffle
_nhwc_kyxc_nhwk
.hpp
→
include/ck/tensor_operation/gpu/device/device_convnd_b
w
d_weight_
nwc_kxc_nwk_
xdl_cshuffle.hpp
View file @
0be1cf14
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp
View file @
0be1cf14
...
...
@@ -39,7 +39,7 @@ namespace device {
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template
<
ck
::
index_t
N
um
DimSpatial
,
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
...
...
@@ -74,16 +74,16 @@ template <ck::index_t NumDimSpatial,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConvNdFwdNwcKxcNwk_Xdl
:
public
DeviceConvFwd
<
N
um
DimSpatial
,
ck
::
tuple_element_t
<
N
um
DimSpatial
-
1
,
:
public
DeviceConvFwd
<
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>>
,
ck
::
tuple_element_t
<
N
um
DimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
tuple_element_t
<
N
um
DimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>>
,
...
...
@@ -94,27 +94,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
Base
=
DeviceConvFwd
<
NumDimSpatial
,
ck
::
tuple_element_t
<
NumDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>>
,
ck
::
tuple_element_t
<
NumDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
tuple_element_t
<
NumDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>>
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
;
using
DeviceOp
=
DeviceConvNdFwdNwcKxcNwk_Xdl
;
using
ADataType
=
InDataType
;
...
...
@@ -124,8 +103,6 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
// TODO make A/B datatype different
using
ABDataType
=
InDataType
;
static
constexpr
index_t
NDimSpatial
=
NumDimSpatial
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -599,18 +576,18 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
// C = A^T*B
// A:
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
GetInputTensorDescriptor
<
N
um
DimSpatial
>
(
N
,
C
,
GemmMRaw
,
GemmK
,
GemmMPad
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
GetInputTensorDescriptor
<
NDimSpatial
>
(
N
,
C
,
GemmMRaw
,
GemmK
,
GemmMPad
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
// B:
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
GetWeightTensorDescriptor
(
GemmN
,
GemmK
);
// C:
...
...
@@ -642,7 +619,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
1
,
1
,
1
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
});
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
N
um
DimSpatial
>
());
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
...
...
@@ -934,7 +911,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
N
um
DimSpatial
;
++
i
)
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
...
...
@@ -947,7 +924,7 @@ struct DeviceConvNdFwdNwcKxcNwk_Xdl
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// check if it's 1x1 conv
for
(
ck
::
index_t
i
=
0
;
i
<
N
um
DimSpatial
;
++
i
)
for
(
ck
::
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment