Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5076982b
Commit
5076982b
authored
Apr 29, 2022
by
Chao Liu
Browse files
format
parent
8c03672b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
103 additions
and
60 deletions
+103
-60
include/ck/tensor_description/tensor_descriptor.hpp
include/ck/tensor_description/tensor_descriptor.hpp
+2
-3
include/ck/tensor_description/tensor_descriptor_helper.hpp
include/ck/tensor_description/tensor_descriptor_helper.hpp
+6
-8
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+2
-2
library/include/ck/library/host_tensor/host_reduction.hpp
library/include/ck/library/host_tensor/host_reduction.hpp
+5
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
...e_tensor_operation/cpu/reference_conv_backward_weight.hpp
+13
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+34
-14
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
...nsor_operation/cpu/reference_conv_fwd_bias_activation.hpp
+12
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
..._operation/cpu/reference_conv_fwd_bias_activation_add.hpp
+12
-4
library/src/utility/conv_fwd_util.cpp
library/src/utility/conv_fwd_util.cpp
+17
-18
No files found.
include/ck/tensor_description/tensor_descriptor.hpp
View file @
5076982b
...
...
@@ -391,9 +391,8 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
remove_cv_t
<
decltype
(
all_low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
all_up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
new_visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
all_transforms
,
element_space_size
,
real_size
};
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
all_transforms
,
element_space_size
,
real_size
};
}
template
<
typename
TensorDesc
,
typename
VisibleIndex
>
...
...
include/ck/tensor_description/tensor_descriptor_helper.hpp
View file @
5076982b
...
...
@@ -72,8 +72,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
const
auto
element_space_size
=
f
(
f
,
Number
<
0
>
{},
Number
<
1
>
{});
#else
const
auto
real_size
=
calculate_element_space_size_impl
(
lengths
,
strides
,
Number
<
0
>
{},
integral_constant
<
std
::
size_t
,
1ul
>
{});
const
auto
real_size
=
calculate_element_space_size_impl
(
lengths
,
strides
,
Number
<
0
>
{},
integral_constant
<
std
::
size_t
,
1ul
>
{});
const
auto
element_space_size
=
calculate_element_space_size_impl
(
lengths
,
strides
,
Number
<
0
>
{},
Number
<
1
>
{});
...
...
@@ -84,9 +84,8 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
,
real_size
};
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
,
real_size
};
}
// Lengths... can be:
...
...
@@ -116,9 +115,8 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
remove_cv_t
<
decltype
(
low_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
up_dim_hidden_idss
)
>
,
remove_cv_t
<
decltype
(
visible_dim_hidden_ids
)
>
,
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
,
real_size
};
remove_cv_t
<
decltype
(
element_space_size
)
>>
{
transforms
,
element_space_size
,
real_size
};
}
template
<
typename
...
Lengths
,
typename
Align
>
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
5076982b
...
...
@@ -385,8 +385,8 @@ struct DeviceBatchedGemmXdl
c_grid_desc_m_n_
{
DeviceBatchedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
compute_ptr_offset_of_batch_
{
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
(),
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
(),
c_grid_desc_m_n_
.
GetElementSpaceSize
()},
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
(),
c_grid_desc_m_n_
.
GetElementSpaceSize
()},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
library/include/ck/library/host_tensor/host_reduction.hpp
View file @
5076982b
...
...
@@ -211,7 +211,8 @@ struct ReductionHost
AccDataType
accuVal
=
ReduceOpZeroVal
<
AccDataType
,
ReduceOpId
>
();
IndexDataType
accuIndex
=
0
;
for
(
IndexDataType
i
=
0
;
i
<
ck
::
type_convert
<
IndexDataType
>
(
reduce_dim_indexes
.
size
());
i
++
)
for
(
IndexDataType
i
=
0
;
i
<
ck
::
type_convert
<
IndexDataType
>
(
reduce_dim_indexes
.
size
());
i
++
)
{
auto
offset_reduce
=
get_offset_from_index
<
NumReduceDim
>
(
reduceStrides
,
reduce_dim_indexes
[
i
]);
...
...
@@ -246,7 +247,9 @@ struct ReductionHost
auto
offset_invariant
=
get_offset_from_index
<
NumInvariantDim
>
(
invariantStrides
,
invariant_index
);
for
(
IndexDataType
i
=
0
;
i
<
ck
::
type_convert
<
IndexDataType
>
(
reduce_dim_indexes
.
size
());
i
++
)
for
(
IndexDataType
i
=
0
;
i
<
ck
::
type_convert
<
IndexDataType
>
(
reduce_dim_indexes
.
size
());
i
++
)
{
auto
offset_reduce
=
get_offset_from_index
<
NumReduceDim
>
(
reduceStrides
,
reduce_dim_indexes
[
i
]);
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
View file @
5076982b
...
...
@@ -70,18 +70,26 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
int
n
=
0
;
n
<
ck
::
type_convert
<
int
>
(
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
0
]);
++
n
)
for
(
int
n
=
0
;
n
<
ck
::
type_convert
<
int
>
(
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
0
]);
++
n
)
{
for
(
int
ho
=
0
;
ho
<
ck
::
type_convert
<
int
>
(
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
2
]);
++
ho
)
for
(
int
ho
=
0
;
ho
<
ck
::
type_convert
<
int
>
(
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
2
]);
++
ho
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
I0
]
+
y
*
arg
.
conv_dilations_
[
I0
]
-
arg
.
in_left_pads_
[
I0
];
for
(
int
wo
=
0
;
wo
<
ck
::
type_convert
<
int
>
(
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
3
]);
++
wo
)
for
(
int
wo
=
0
;
wo
<
ck
::
type_convert
<
int
>
(
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
3
]);
++
wo
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
I1
]
+
x
*
arg
.
conv_dilations_
[
I1
]
-
arg
.
in_left_pads_
[
I1
];
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
]))
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
]))
{
float
v_out
;
float
v_in
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
5076982b
...
...
@@ -88,13 +88,16 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_ncw
=
[
&
](
auto
n
,
auto
k
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
0
]
+
x
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
if
(
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]))
if
(
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]))
{
float
v_in
;
float
v_wei
;
...
...
@@ -128,17 +131,23 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]))
{
float
v_in
;
...
...
@@ -174,23 +183,34 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
z
=
0
;
z
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
z
)
for
(
int
z
=
0
;
z
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
]);
++
z
)
{
int
di
=
d_o
*
arg
.
conv_strides_
[
0
]
+
z
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
y
)
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
]);
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
1
]
+
y
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
2
]
+
x
*
arg
.
conv_dilations_
[
2
]
-
arg
.
in_left_pads_
[
2
];
if
(
di
>=
0
&&
di
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
4
]))
if
(
di
>=
0
&&
di
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
&&
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
input_
.
mDesc
.
GetLengths
()[
4
]))
{
float
v_in
;
float
v_wei
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
View file @
5076982b
...
...
@@ -73,17 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
]))
{
float
v_in
;
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
View file @
5076982b
...
...
@@ -76,17 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
for
(
int
c
=
0
;
c
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
]);
++
c
)
{
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
for
(
int
y
=
0
;
y
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
]);
++
y
)
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
arg
.
in_left_pads_
[
0
];
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
for
(
int
x
=
0
;
x
<
ck
::
type_convert
<
int
>
(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
]);
++
x
)
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
arg
.
in_left_pads_
[
1
];
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
])
&&
wi
>=
0
&&
wi
<
ck
::
type_convert
<
int
>
(
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
]))
{
float
v_in
;
...
...
library/src/utility/conv_fwd_util.cpp
View file @
5076982b
...
...
@@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N,
}
ConvParams
::
ConvParams
()
:
num_dim_spatial
(
2
),
N
(
128
),
K
(
256
),
C
(
192
),
filter_spatial_lengths
(
2
,
3
),
input_spatial_lengths
(
2
,
71
),
conv_filter_strides
(
2
,
2
),
conv_filter_dilations
(
2
,
1
),
input_left_pads
(
2
,
1
),
input_right_pads
(
2
,
1
)
:
num_dim_spatial
(
2
),
N
(
128
),
K
(
256
),
C
(
192
),
filter_spatial_lengths
(
2
,
3
),
input_spatial_lengths
(
2
,
71
),
conv_filter_strides
(
2
,
2
),
conv_filter_dilations
(
2
,
1
),
input_left_pads
(
2
,
1
),
input_right_pads
(
2
,
1
)
{
}
...
...
@@ -78,9 +78,9 @@ ConvParams::ConvParams(ck::index_t n_dim,
ck
::
type_convert
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
ck
::
type_convert
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
{
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"
));
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"
));
}
}
...
...
@@ -93,9 +93,9 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
ck
::
type_convert
<
ck
::
index_t
>
(
input_left_pads
.
size
())
!=
num_dim_spatial
||
ck
::
type_convert
<
ck
::
index_t
>
(
input_right_pads
.
size
())
!=
num_dim_spatial
)
{
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"
));
throw
(
std
::
runtime_error
(
"ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!"
));
}
std
::
vector
<
ck
::
index_t
>
out_spatial_len
(
num_dim_spatial
,
0
);
...
...
@@ -103,8 +103,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
idx_eff
=
(
filter_spatial_lengths
[
i
]
-
1
)
*
conv_filter_dilations
[
i
]
+
1
;
const
ck
::
index_t
idx_eff
=
(
filter_spatial_lengths
[
i
]
-
1
)
*
conv_filter_dilations
[
i
]
+
1
;
out_spatial_len
[
i
]
=
(
input_spatial_lengths
[
i
]
+
input_left_pads
[
i
]
+
input_right_pads
[
i
]
-
idx_eff
)
/
conv_filter_strides
[
i
]
+
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment