Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2a09df0f
Commit
2a09df0f
authored
Aug 04, 2023
by
Bartlomiej Kocot
Browse files
Unify backward weight api with forward
parent
23220fe5
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
212 additions
and
239 deletions
+212
-239
client_example/11_grouped_conv_bwd_weight/common.hpp
client_example/11_grouped_conv_bwd_weight/common.hpp
+39
-40
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+9
-13
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+6
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+72
-80
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+77
-83
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
...include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+9
-13
No files found.
client_example/11_grouped_conv_bwd_weight/common.hpp
View file @
2a09df0f
...
@@ -36,17 +36,18 @@ std::size_t GetFlops(ck::index_t G,
...
@@ -36,17 +36,18 @@ std::size_t GetFlops(ck::index_t G,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_
spatial_
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_
spatial_
lengths
)
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
{
{
constexpr
index_t
spatial_offset
=
3
;
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G
*
N
*
K
*
C
*
return
static_cast
<
std
::
size_t
>
(
2
)
*
G
*
N
*
K
*
C
*
std
::
accumulate
(
std
::
begin
(
output_
spatial_
lengths
),
std
::
accumulate
(
std
::
begin
(
output_lengths
)
+
spatial_offset
,
std
::
end
(
output_
spatial_
lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
())
*
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
begin
(
filter_
spatial_
lengths
),
std
::
accumulate
(
std
::
begin
(
filter_lengths
)
+
spatial_offset
,
std
::
end
(
filter_
spatial_
lengths
),
std
::
end
(
filter_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
());
std
::
multiplies
<>
());
}
}
...
@@ -55,12 +56,13 @@ template <typename InDataType, ck::index_t NumDimSpatial>
...
@@ -55,12 +56,13 @@ template <typename InDataType, ck::index_t NumDimSpatial>
std
::
size_t
GetInputByte
(
ck
::
index_t
G
,
std
::
size_t
GetInputByte
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_
spatial_
lengths
)
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_lengths
)
{
{
constexpr
index_t
spatial_offset
=
3
;
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G
*
N
*
C
*
return
sizeof
(
InDataType
)
*
(
G
*
N
*
C
*
std
::
accumulate
(
std
::
begin
(
input_
spatial_
lengths
),
std
::
accumulate
(
std
::
begin
(
input_lengths
)
+
spatial_offset
,
std
::
end
(
input_
spatial_
lengths
),
std
::
end
(
input_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
std
::
multiplies
<>
()));
}
}
...
@@ -69,12 +71,13 @@ template <typename WeiDataType, ck::index_t NumDimSpatial>
...
@@ -69,12 +71,13 @@ template <typename WeiDataType, ck::index_t NumDimSpatial>
std
::
size_t
GetWeightByte
(
ck
::
index_t
G
,
std
::
size_t
GetWeightByte
(
ck
::
index_t
G
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_
spatial_
lengths
)
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
{
{
constexpr
index_t
spatial_offset
=
3
;
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G
*
K
*
C
*
return
sizeof
(
WeiDataType
)
*
(
G
*
K
*
C
*
std
::
accumulate
(
std
::
begin
(
filter_
spatial_
lengths
),
std
::
accumulate
(
std
::
begin
(
filter_lengths
)
+
spatial_offset
,
std
::
end
(
filter_
spatial_
lengths
),
std
::
end
(
filter_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
std
::
multiplies
<>
()));
}
}
...
@@ -83,12 +86,13 @@ template <typename OutDataType, ck::index_t NumDimSpatial>
...
@@ -83,12 +86,13 @@ template <typename OutDataType, ck::index_t NumDimSpatial>
std
::
size_t
GetOutputByte
(
ck
::
index_t
G
,
std
::
size_t
GetOutputByte
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_
spatial_
lengths
)
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
)
{
{
constexpr
index_t
spatial_offset
=
3
;
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G
*
N
*
K
*
return
sizeof
(
OutDataType
)
*
(
G
*
N
*
K
*
std
::
accumulate
(
std
::
begin
(
output_
spatial_
lengths
),
std
::
accumulate
(
std
::
begin
(
output_lengths
)
+
spatial_offset
,
std
::
end
(
output_
spatial_
lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
std
::
multiplies
<
std
::
size_t
>
()));
}
}
...
@@ -105,9 +109,9 @@ bool run_grouped_conv_bwd_weight(
...
@@ -105,9 +109,9 @@ bool run_grouped_conv_bwd_weight(
const
ck
::
index_t
N
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_
spatial_
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_
spatial_
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
filter_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_
spatial_
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_strides
,
...
@@ -118,9 +122,9 @@ bool run_grouped_conv_bwd_weight(
...
@@ -118,9 +122,9 @@ bool run_grouped_conv_bwd_weight(
{
{
ck
::
index_t
split_k
=
2
;
ck
::
index_t
split_k
=
2
;
SimpleDeviceMem
in
(
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
G
,
N
,
C
,
input_spatial
_lengths
));
SimpleDeviceMem
in
(
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
input
_lengths
));
SimpleDeviceMem
wei
(
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
G
,
K
,
C
,
filter_spatial
_lengths
));
SimpleDeviceMem
wei
(
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
filter
_lengths
));
SimpleDeviceMem
out
(
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
G
,
N
,
K
,
output_spatial
_lengths
));
SimpleDeviceMem
out
(
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
output
_lengths
));
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight
<
NumDimSpatial
,
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight
<
NumDimSpatial
,
InLayout
,
InLayout
,
...
@@ -144,6 +148,10 @@ bool run_grouped_conv_bwd_weight(
...
@@ -144,6 +148,10 @@ bool run_grouped_conv_bwd_weight(
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
float
best_tflops
=
0
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
// profile device operation instances
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
...
@@ -153,15 +161,11 @@ bool run_grouped_conv_bwd_weight(
...
@@ -153,15 +161,11 @@ bool run_grouped_conv_bwd_weight(
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
G
,
input_lengths
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
input_strides
,
filter_lengths
,
weights_strides
,
weights_strides
,
output_lengths
,
output_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
...
@@ -178,12 +182,11 @@ bool run_grouped_conv_bwd_weight(
...
@@ -178,12 +182,11 @@ bool run_grouped_conv_bwd_weight(
{
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
>
(
G
,
N
,
K
,
C
,
output_lengths
,
filter_lengths
);
GetFlops
<
NumDimSpatial
>
(
G
,
N
,
K
,
C
,
output_spatial_lengths
,
filter_spatial_lengths
);
std
::
size_t
num_bytes
=
std
::
size_t
num_bytes
=
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
G
,
N
,
C
,
input_
spatial_
lengths
)
+
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
G
,
N
,
C
,
input_lengths
)
+
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
G
,
K
,
C
,
filter_
spatial_
lengths
)
+
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
G
,
K
,
C
,
filter_lengths
)
+
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
G
,
N
,
K
,
output_
spatial_
lengths
);
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
G
,
N
,
K
,
output_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
...
@@ -223,15 +226,11 @@ bool run_grouped_conv_bwd_weight(
...
@@ -223,15 +226,11 @@ bool run_grouped_conv_bwd_weight(
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
out
.
GetDeviceBuffer
(),
G
,
input_lengths
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
input_strides
,
filter_lengths
,
weights_strides
,
weights_strides
,
output_lengths
,
output_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
...
...
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
2a09df0f
...
@@ -72,11 +72,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -72,11 +72,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
// init to 0
// init to 0
wei_device_buf
.
SetZero
();
wei_device_buf
.
SetZero
();
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
filter_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
...
@@ -85,11 +85,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -85,11 +85,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
begin
(
input_spatial_lengths
));
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
begin
(
filter_spatial_lengths
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
begin
(
output_spatial_lengths
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
...
@@ -102,15 +102,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -102,15 +102,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
auto
argument
=
conv
.
MakeArgument
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
conv
.
MakeArgument
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
conv_param
.
G_
,
input_lengths
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
input_strides
,
filter_lengths
,
weights_strides
,
weights_strides
,
output_lengths
,
output_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
2a09df0f
...
@@ -27,16 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
...
@@ -27,16 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in
,
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
void
*
p_wei
,
const
void
*
p_out
,
const
void
*
p_out
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
2a09df0f
...
@@ -784,16 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -784,16 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*a_g_n_c_wis_strides*/
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*b_g_k_c_xs_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*e_g_n_k_wos_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*input_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*weights_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*output_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -813,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -813,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
c_element_op_
{
in_element_op
},
Conv_G_
{
G
},
Conv_G_
{
a_g_n_c_wis_lengths
[
0
]
},
Conv_N_
{
N
},
Conv_N_
{
a_g_n_c_wis_lengths
[
1
]
},
Conv_K_
{
K
},
Conv_K_
{
b_g_k_c_xs_lengths
[
1
]
},
Conv_C_
{
C
},
Conv_C_
{
a_g_n_c_wis_lengths
[
2
]
},
input_spatial_lengths_
{
input_spatial_lengths
},
input_spatial_lengths_
{},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{},
output_spatial_lengths_
{
output_spatial_lengths
},
output_spatial_lengths_
{},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
k_batch_
{
split_k
}
{
{
constexpr
index_t
spatial_offset
=
3
;
std
::
copy
(
begin
(
a_g_n_c_wis_lengths
)
+
spatial_offset
,
end
(
a_g_n_c_wis_lengths
),
begin
(
input_spatial_lengths_
));
std
::
copy
(
begin
(
b_g_k_c_xs_lengths
)
+
spatial_offset
,
end
(
b_g_k_c_xs_lengths
),
begin
(
filter_spatial_lengths_
));
std
::
copy
(
begin
(
e_g_n_k_wos_lengths
)
+
spatial_offset
,
end
(
e_g_n_k_wos_lengths
),
begin
(
output_spatial_lengths_
));
const
auto
descs
=
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
N
,
Conv_N_
,
K
,
Conv_K_
,
C
,
C
onv_C_
,
input_spatial_lengths
,
input_spatial_lengths
_
,
filter_spatial_lengths
,
filter_spatial_lengths
_
,
output_spatial_lengths
,
output_spatial_lengths
_
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -857,19 +864,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -857,19 +864,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
Conv_N_
*
Conv_K_
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
std
::
accumulate
(
begin
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
Conv_N_
*
Conv_C_
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
std
::
accumulate
(
begin
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
index_t
{
1
},
index_t
{
1
},
...
@@ -905,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -905,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const
index_t
Conv_K_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
const
index_t
Conv_C_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
...
@@ -1111,19 +1118,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1111,19 +1118,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1136,16 +1140,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1136,16 +1140,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
weights_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -1162,16 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1162,16 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
void
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1184,16 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1184,16 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
weights_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
2a09df0f
...
@@ -1011,16 +1011,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1011,16 +1011,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1045,28 +1041,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1045,28 +1041,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
in_element_op
},
b_element_op_
{
in_element_op
},
c_element_op_
{
wei_element_op
},
c_element_op_
{
wei_element_op
},
Conv_G_
{
G
},
Conv_G_
{
a_g_n_c_wis_lengths
[
0
]},
Conv_N_
{
N
},
Conv_N_
{
a_g_n_c_wis_lengths
[
1
]},
Conv_K_
{
K
},
Conv_K_
{
b_g_k_c_xs_lengths
[
1
]},
Conv_C_
{
C
},
Conv_C_
{
a_g_n_c_wis_lengths
[
2
]},
output_spatial_lengths_
{
output_spatial_lengths
},
input_spatial_lengths_
{},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{},
output_spatial_lengths_
{},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
k_batch_
{
split_k
}
{
{
constexpr
index_t
spatial_offset
=
3
;
std
::
copy
(
begin
(
a_g_n_c_wis_lengths
)
+
spatial_offset
,
end
(
a_g_n_c_wis_lengths
),
begin
(
input_spatial_lengths_
));
std
::
copy
(
begin
(
b_g_k_c_xs_lengths
)
+
spatial_offset
,
end
(
b_g_k_c_xs_lengths
),
begin
(
filter_spatial_lengths_
));
std
::
copy
(
begin
(
e_g_n_k_wos_lengths
)
+
spatial_offset
,
end
(
e_g_n_k_wos_lengths
),
begin
(
output_spatial_lengths_
));
const
auto
descs
=
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
N
,
Conv_N_
,
K
,
Conv_K_
,
C
,
C
onv_C_
,
input_spatial_lengths
,
input_spatial_lengths
_
,
filter_spatial_lengths
,
filter_spatial_lengths
_
,
output_spatial_lengths
,
output_spatial_lengths
_
,
input
_strides
,
a_g_n_c_wis
_strides
,
weight
s_strides
,
b_g_k_c_x
s_strides
,
output
_strides
,
e_g_n_k_wos
_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -1081,12 +1089,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1081,12 +1089,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
// A/B/C Batch Stride
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
output
_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
e_g_n_k_wos
_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
input
_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
a_g_n_c_wis
_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
std
::
accumulate
(
begin
(
filter_spatial_lengths
_
),
end
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
...
@@ -1125,8 +1133,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1125,8 +1133,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const
index_t
Conv_N_
;
const
index_t
Conv_N_
;
const
index_t
Conv_K_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
const
index_t
Conv_C_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads_
;
...
@@ -1301,19 +1310,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1301,19 +1310,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1326,16 +1332,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1326,16 +1332,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
weights_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -1354,16 +1356,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1354,16 +1356,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
void
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1376,16 +1374,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -1376,16 +1374,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
weights_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
View file @
2a09df0f
...
@@ -136,9 +136,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
...
@@ -136,9 +136,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
// profile device Conv instances
// profile device Conv instances
bool
all_pass
=
true
;
bool
all_pass
=
true
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
filter_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
...
@@ -149,11 +149,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
...
@@ -149,11 +149,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
begin
(
input_lengths
));
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
begin
(
filter_lengths
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
begin
(
output_lengths
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
...
@@ -166,15 +166,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
...
@@ -166,15 +166,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
conv_param
.
G_
,
input_lengths
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
input_strides
,
filter_lengths
,
weights_strides
,
weights_strides
,
output_lengths
,
output_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment