Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8a8dca0a
Commit
8a8dca0a
authored
Aug 05, 2023
by
Bartlomiej Kocot
Browse files
Fixes for examples
parent
83360328
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
62 additions
and
105 deletions
+62
-105
client_example/11_grouped_conv_bwd_weight/common.hpp
client_example/11_grouped_conv_bwd_weight/common.hpp
+23
-46
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
+6
-10
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
+6
-10
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
+6
-10
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
...rouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
+6
-10
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+6
-6
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
...d_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
+9
-13
No files found.
client_example/11_grouped_conv_bwd_weight/common.hpp
View file @
8a8dca0a
...
...
@@ -32,17 +32,14 @@ struct SimpleDeviceMem
};
template
<
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetFlops
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
,
std
::
size_t
GetFlops
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
{
constexpr
index_t
spatial_offset
=
3
;
constexpr
ck
::
index_t
spatial_offset
=
3
;
const
auto
C
=
filter_lengths
[
2
];
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G
*
N
*
K
*
C
*
std
::
accumulate
(
std
::
begin
(
output_lengths
)
+
spatial_offset
,
return
static_cast
<
std
::
size_t
>
(
2
)
*
C
*
std
::
accumulate
(
std
::
begin
(
output_lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
())
*
...
...
@@ -53,45 +50,30 @@ std::size_t GetFlops(ck::index_t G,
}
template
<
typename
InDataType
,
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetInputByte
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_lengths
)
std
::
size_t
GetInputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
input_lengths
)
{
constexpr
index_t
spatial_offset
=
3
;
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G
*
N
*
C
*
std
::
accumulate
(
std
::
begin
(
input_lengths
)
+
spatial_offset
,
return
sizeof
(
InDataType
)
*
(
std
::
accumulate
(
std
::
begin
(
input_lengths
),
std
::
end
(
input_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
}
template
<
typename
WeiDataType
,
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetWeightByte
(
ck
::
index_t
G
,
ck
::
index_t
K
,
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
std
::
size_t
GetWeightByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
filter_lengths
)
{
constexpr
index_t
spatial_offset
=
3
;
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G
*
K
*
C
*
std
::
accumulate
(
std
::
begin
(
filter_lengths
)
+
spatial_offset
,
return
sizeof
(
WeiDataType
)
*
(
std
::
accumulate
(
std
::
begin
(
filter_lengths
),
std
::
end
(
filter_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
}
template
<
typename
OutDataType
,
ck
::
index_t
NumDimSpatial
>
std
::
size_t
GetOutputByte
(
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
)
std
::
size_t
GetOutputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
output_lengths
)
{
constexpr
index_t
spatial_offset
=
3
;
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G
*
N
*
K
*
std
::
accumulate
(
std
::
begin
(
output_lengths
)
+
spatial_offset
,
return
sizeof
(
OutDataType
)
*
(
std
::
accumulate
(
std
::
begin
(
output_lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
...
...
@@ -105,15 +87,11 @@ template <ck::index_t NumDimSpatial,
typename
WeiLayout
,
typename
OutLayout
>
bool
run_grouped_conv_bwd_weight
(
const
ck
::
index_t
G
,
const
ck
::
index_t
N
,
const
ck
::
index_t
K
,
const
ck
::
index_t
C
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
filter_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
filter_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
weights_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>&
conv_filter_dilations
,
...
...
@@ -122,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
{
ck
::
index_t
split_k
=
2
;
SimpleDeviceMem
in
(
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
input_lengths
));
SimpleDeviceMem
wei
(
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
filter_lengths
));
SimpleDeviceMem
out
(
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
output_lengths
));
SimpleDeviceMem
in
(
GetInputByte
<
InDataType
,
NumDimSpatial
+
3
>
(
input_lengths
));
SimpleDeviceMem
wei
(
GetWeightByte
<
WeiDataType
,
NumDimSpatial
+
3
>
(
filter_lengths
));
SimpleDeviceMem
out
(
GetOutputByte
<
OutDataType
,
NumDimSpatial
+
3
>
(
output_lengths
));
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdWeight
<
NumDimSpatial
,
InLayout
,
...
...
@@ -148,9 +126,9 @@ bool run_grouped_conv_bwd_weight(
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
N
um
DimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
N
um
DimSpatial
+
3
>
a_g_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
N
um
DimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
...
...
@@ -182,11 +160,10 @@ bool run_grouped_conv_bwd_weight(
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
>
(
G
,
N
,
K
,
C
,
output_lengths
,
filter_lengths
);
std
::
size_t
num_bytes
=
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
G
,
N
,
C
,
input_lengths
)
+
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
G
,
K
,
C
,
filter_lengths
)
+
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
G
,
N
,
K
,
output_lengths
);
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
+
3
>
(
output_lengths
,
filter_lengths
);
std
::
size_t
num_bytes
=
GetInputByte
<
InDataType
,
NumDimSpatial
+
3
>
(
input_lengths
)
+
GetWeightByte
<
WeiDataType
,
NumDimSpatial
+
3
>
(
filter_lengths
)
+
GetOutputByte
<
OutDataType
,
NumDimSpatial
+
3
>
(
output_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp
View file @
8a8dca0a
...
...
@@ -22,9 +22,9 @@ static constexpr ck::index_t C = 192;
static
constexpr
ck
::
index_t
X
=
3
;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Wi
*
C
,
Wi
*
C
,
1
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
X
*
C
,
X
*
C
,
1
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
N
*
Wo
*
K
,
Wo
*
K
,
1
,
K
};
...
...
@@ -41,15 +41,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp
View file @
8a8dca0a
...
...
@@ -25,9 +25,9 @@ static constexpr ck::index_t Hi = 28;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Hi
*
Wi
*
C
,
Hi
*
Wi
*
C
,
1
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
...
...
@@ -47,15 +47,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp
View file @
8a8dca0a
...
...
@@ -28,9 +28,9 @@ static constexpr ck::index_t Wi = 3;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
1
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
...
...
@@ -50,15 +50,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp
View file @
8a8dca0a
...
...
@@ -28,9 +28,9 @@ static constexpr ck::index_t Wi = 3;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_
spatial_
lengths
{
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_
spatial_
lengths
{
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
output_
spatial_
lengths
{
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Di
,
Hi
,
Wi
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
filter_lengths
{
G
,
K
,
C
,
Z
,
Y
,
X
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Do
,
Ho
,
Wo
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
N
*
Di
*
Hi
*
Wi
*
C
,
Di
*
Hi
*
Wi
*
C
,
1
,
Hi
*
Wi
*
C
,
Wi
*
C
,
C
};
static
constexpr
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
...
...
@@ -50,15 +50,11 @@ int main()
OutDataType
,
InLayout
,
WeiLayout
,
OutLayout
>
(
G
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
OutLayout
>
(
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
8a8dca0a
...
...
@@ -865,20 +865,20 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
Conv_N_
*
Conv_K_
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
std
::
accumulate
(
begin
(
output_spatial_lengths
_
),
end
(
output_spatial_lengths
_
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
Conv_N_
*
Conv_C_
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
std
::
accumulate
(
begin
(
input_spatial_lengths
_
),
end
(
input_spatial_lengths
_
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
std
::
accumulate
(
begin
(
filter_spatial_lengths
_
),
end
(
filter_spatial_lengths
_
),
index_t
{
1
},
std
::
multiplies
<>
{});
}
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
View file @
8a8dca0a
...
...
@@ -70,9 +70,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
conv_param
);
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
filter_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
...
...
@@ -83,11 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
begin
(
input_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
begin
(
filter_lengths
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
begin
(
output_lengths
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
...
...
@@ -99,15 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto
argument
=
conv
.
MakeArgument
(
nullptr
,
nullptr
,
nullptr
,
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment