Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a6d0b5e1
Commit
a6d0b5e1
authored
Dec 27, 2022
by
Rosty Geyyer
Browse files
Update conv_bwd_weight_dl to grouped_conv_bwd_weight_dl
parent
05ee41c3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
238 additions
and
62 deletions
+238
-62
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+3
-3
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
...ouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp
+177
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_nwc_kxc_nwk_dl.hpp
...ce/impl/device_grouped_conv_bwd_weight_nwc_kxc_nwk_dl.hpp
+58
-52
No files found.
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
a6d0b5e1
...
@@ -7,8 +7,8 @@ add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd
...
@@ -7,8 +7,8 @@ add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
add_dependencies
(
example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
example_grouped_conv_bwd_weight_xdl_bf16
)
example_grouped_conv_bwd_weight_xdl_bf16
)
add_custom_target
(
example_conv_bwd_weight
)
add_custom_target
(
example_
grouped_
conv_bwd_weight
_dl
)
add_example_executable
(
example_conv
nd
_bwd_weight_dl_fp16 conv
nd
_bwd_weight_dl_fp16.cpp
)
add_example_executable
(
example_
grouped_
conv_bwd_weight_dl_fp16
grouped_
conv_bwd_weight_dl_fp16.cpp
)
add_dependencies
(
example_conv_bwd_weight example_conv
nd
_bwd_weight_dl_fp16
)
add_dependencies
(
example_
grouped_
conv_bwd_weight
_dl
example_
grouped_
conv_bwd_weight_dl_fp16
)
example/20_grouped_conv_bwd_weight/conv
nd
_bwd_weight_dl_fp16.cpp
→
example/20_grouped_conv_bwd_weight/
grouped_
conv_bwd_weight_dl_fp16.cpp
View file @
a6d0b5e1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "c
onvnd_bwd_weight_common
.hpp"
#include "c
k/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_nwc_kxc_nwk_dl
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_weight_nwc_kxc_nwk_dl.hpp"
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp"
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
...
@@ -21,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
...
@@ -21,7 +36,7 @@ static constexpr auto ConvBwdWeightDefault =
ck
::
tensor_operation
::
device
::
ConvolutionBackwardWeightSpecialization
::
Default
;
ck
::
tensor_operation
::
device
::
ConvolutionBackwardWeightSpecialization
::
Default
;
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
using
DeviceConv
nd
BwdWeightInstance
=
using
DeviceConvBwdWeightInstance
=
ck
::
tensor_operation
::
device
::
DeviceConvNdBwdWeightNwcKxcNwk_Dl
<
ck
::
tensor_operation
::
device
::
DeviceConvNdBwdWeightNwcKxcNwk_Dl
<
NDimSpatial
,
// NDimSpatial
NDimSpatial
,
// NDimSpatial
InDataType
,
// InDataType
InDataType
,
// InDataType
...
@@ -60,6 +75,161 @@ using DeviceConvndBwdWeightInstance =
...
@@ -60,6 +75,161 @@ using DeviceConvndBwdWeightInstance =
5
,
// CThreadTransferSrcDstVectorDim
5
,
// CThreadTransferSrcDstVectorDim
4
>
;
// CThreadTransferDstScalarPerVector
4
>
;
// CThreadTransferDstScalarPerVector
template
<
ck
::
index_t
NDimSpatial
>
using
HostConvBwdWeightInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdWeight
<
NDimSpatial
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
InElementOp
,
typename
WeiElementOp
,
typename
OutElementOp
,
typename
DeviceConvBwdWeightInstance
>
int
run_conv_bwd_weight
(
bool
do_verification
,
int
init_method
,
bool
time_kernel
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
,
const
HostTensorDescriptor
&
in_g_n_c_wis_desc
,
const
HostTensorDescriptor
&
wei_g_k_c_xs_desc
,
const
HostTensorDescriptor
&
out_g_n_k_wos_desc
,
const
InElementOp
&
in_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
OutElementOp
&
out_element_op
,
ck
::
index_t
split_k
)
{
Tensor
<
InDataType
>
in
(
in_g_n_c_wis_desc
);
Tensor
<
WeiDataType
>
wei_host_result
(
wei_g_k_c_xs_desc
);
Tensor
<
WeiDataType
>
wei_device_result
(
wei_g_k_c_xs_desc
);
Tensor
<
OutDataType
>
out
(
out_g_n_k_wos_desc
);
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out: "
<<
out
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
out
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
break
;
default:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
out
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpaceSize
());
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
out_device_buf
.
ToDevice
(
out
.
mData
.
data
());
// init to 0
wei_device_buf
.
SetZero
();
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
{};
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
range_copy
(
conv_param
.
input_left_pads_
,
begin
(
input_left_pads
));
range_copy
(
conv_param
.
input_right_pads_
,
begin
(
input_right_pads
));
// do GEMM
auto
conv
=
DeviceConvBwdWeightInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
argument
=
conv
.
MakeArgument
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
,
split_k
);
if
(
!
conv
.
IsSupportedArgument
(
argument
))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
float
avg_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
conv_param
.
GetFlops
();
std
::
size_t
num_btype
=
conv_param
.
GetByte
<
InDataType
,
WeiDataType
,
OutDataType
>
();
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
conv
.
GetTypeString
()
<<
std
::
endl
;
if
(
do_verification
)
{
auto
ref_conv
=
HostConvBwdWeightInstance
<
NDimSpatial
>
{};
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in
,
wei_host_result
,
out
,
conv_param
.
conv_filter_strides_
,
conv_param
.
conv_filter_dilations_
,
conv_param
.
input_left_pads_
,
conv_param
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
ref_invoker
.
Run
(
ref_argument
);
wei_device_buf
.
FromDevice
(
wei_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
wei_device_result
.
mData
,
wei_host_result
.
mData
)
?
0
:
1
;
}
return
0
;
}
void
print_helper_msg
()
{
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
...
@@ -71,7 +241,7 @@ int main(int argc, char* argv[])
...
@@ -71,7 +241,7 @@ int main(int argc, char* argv[])
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
ck
::
utils
::
conv
::
ConvParam
conv_param
{
ck
::
utils
::
conv
::
ConvParam
conv_param
{
2
,
1
,
32
,
256
,
1024
,
{
3
,
3
},
{
14
,
14
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}};
2
,
4
,
1
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}};
ck
::
index_t
split_k
=
1
;
ck
::
index_t
split_k
=
1
;
...
@@ -124,7 +294,7 @@ int main(int argc, char* argv[])
...
@@ -124,7 +294,7 @@ int main(int argc, char* argv[])
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceConv
nd
BwdWeightInstance
<
1
>>
(
do_verification
,
DeviceConvBwdWeightInstance
<
1
>>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
...
@@ -161,7 +331,7 @@ int main(int argc, char* argv[])
...
@@ -161,7 +331,7 @@ int main(int argc, char* argv[])
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceConv
nd
BwdWeightInstance
<
2
>>
(
do_verification
,
DeviceConvBwdWeightInstance
<
2
>>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
...
@@ -198,7 +368,7 @@ int main(int argc, char* argv[])
...
@@ -198,7 +368,7 @@ int main(int argc, char* argv[])
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
,
OutElementOp
,
DeviceConv
nd
BwdWeightInstance
<
3
>>
(
do_verification
,
DeviceConvBwdWeightInstance
<
3
>>
(
do_verification
,
init_method
,
init_method
,
time_kernel
,
time_kernel
,
conv_param
,
conv_param
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv
nd
_bwd_weight_nwc_kxc_nwk_dl.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_
grouped_
conv_bwd_weight_nwc_kxc_nwk_dl.hpp
View file @
a6d0b5e1
...
@@ -10,9 +10,8 @@
...
@@ -10,9 +10,8 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_
grouped_
conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -58,7 +57,7 @@ template <ck::index_t NDimSpatial,
...
@@ -58,7 +57,7 @@ template <ck::index_t NDimSpatial,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
>
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConvNdBwdWeightNwcKxcNwk_Dl
struct
DeviceConvNdBwdWeightNwcKxcNwk_Dl
:
public
DeviceConvBwdWeight
<
:
public
Device
Grouped
ConvBwdWeight
<
NDimSpatial
,
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
...
@@ -121,13 +120,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -121,13 +120,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -273,13 +272,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -273,13 +272,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -441,13 +440,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -441,13 +440,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
ck
::
index_t
batch_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -707,16 +706,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -707,16 +706,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
...
@@ -729,6 +729,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -729,6 +729,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
c_element_op_
{
in_element_op
},
Conv_G_
{
G
},
Conv_N_
{
N
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_K_
{
K
},
Conv_C_
{
C
},
Conv_C_
{
C
},
...
@@ -786,17 +787,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -786,17 +787,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
InElementwiseOperation
c_element_op_
;
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
// for checking IsSupportedArgument()
index_t
Conv_G_
;
index_t
Conv_N_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
index_t
Conv_C_
;
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_strides_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
vector
<
ck
::
index_t
>
input_left_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
vector
<
ck
::
index_t
>
input_right_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads_
;
index_t
k_batch_
;
index_t
k_batch_
;
};
};
...
@@ -996,16 +998,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -996,16 +998,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
@@ -1014,6 +1017,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -1014,6 +1017,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
G
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -1035,16 +1039,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -1035,16 +1039,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
OutElementwiseOperation
out_element_op
,
...
@@ -1053,6 +1058,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
...
@@ -1053,6 +1058,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
N
,
N
,
K
,
K
,
C
,
C
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment