Unverified Commit cb722cf9 authored by kahmed10, committed by GitHub

Enable read support for n-dimensional ops (#537)



* initial progress

* formatting

* add pooling changes

* formatting

* change eliminate_pad

* formatting

* rename var

* formatting

* update op shape test and compute

* formatting

* revert conv constructor

* formatting

* change initializer

* formatting

* fix tidy

* change quant conv and shape check

* add tests and fixes

* formatting

* fix type

* fix conv test

* formatting

* add pooling and bn tests

* formatting

* add inconsistent attr tests

* fix padding issue

* formatting

* fix review comments, remove duplicate test

* formatting

* fix variable

* fix assert bug

* fix attr check

* remove std
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 93be5e2b
@@ -40,8 +40,10 @@ void eliminate_pad::update_op(T,
     if(!pad_op.symmetric())
         return;
     std::vector<int64_t> pads = pad_op.pads;
-    std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};
+    auto kdims    = input->get_shape().lens().size() - 2;
+    auto kdims_it = pad_op.pads.begin() + 2;
+    std::vector<size_t> new_pads(kdims_it, kdims_it + kdims);
     T op       = any_cast<T>(ins->get_operator());
     op.padding = new_pads;
......
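For context on the indexing above: the pad operator stores all leading pads first, then all trailing pads, with dims 0 and 1 being N and C, so the spatial pads start at offset 2. A minimal stand-alone sketch of the extraction under that layout (values are illustrative):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // Symmetric pads for a 4-D NCHW tensor: leading {N, C, H, W}, then trailing.
    std::vector<std::int64_t> pads = {0, 0, 1, 2, 0, 0, 1, 2};
    std::size_t kdims = 4 - 2; // two spatial dims
    auto kdims_it     = pads.begin() + 2;
    std::vector<std::size_t> new_pads(kdims_it, kdims_it + kdims);
    assert(new_pads == (std::vector<std::size_t>{1, 2})); // only the spatial pads remain
}
```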
@@ -71,6 +71,19 @@ struct check_shapes
         return *this;
     }
+    const check_shapes& min_ndims(std::size_t n) const
+    {
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        if(begin != end)
+        {
+            if(begin->lens().size() < n)
+                MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
+                               " dimensions");
+        }
+        return *this;
+    }
     const check_shapes& same_shape() const
     {
         if(!this->same([](const shape& s) { return s; }))
......
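Unlike only_dims, which pins the rank exactly, min_ndims only enforces a lower bound; this is what lets the convolution and pooling shape checks below accept 3-D (1-D conv) through 5-D (3-D conv) inputs. A stand-alone sketch of the check's behavior, using a hypothetical min_ndims_check helper rather than the real check_shapes class:

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative stand-in for the new check: reject only when rank is below the bound.
void min_ndims_check(const std::vector<std::size_t>& lens, std::size_t n)
{
    if(lens.size() < n)
        throw std::runtime_error("Shape must have at least " + std::to_string(n) +
                                 " dimensions");
}

int main()
{
    min_ndims_check({1, 3, 224, 224}, 3); // rank 4: passes, as would rank 3 or 5
    try
    {
        min_ndims_check({2, 3}, 3); // rank 2: rejected
    }
    catch(const std::exception& e)
    {
        std::cout << e.what() << '\n';
    }
}
```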
@@ -42,7 +42,7 @@ struct batch_norm_inference
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(5);
-        check_shapes{inputs.data(), inputs.data() + 1, *this}.only_dims(4);
+        check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
         check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements(
             inputs.front().lens()[1]);
         return inputs.front();
......
@@ -19,9 +19,9 @@ namespace op {
 struct convolution
 {
-    std::array<std::size_t, 2> padding  = {{0, 0}};
-    std::array<std::size_t, 2> stride   = {{1, 1}};
-    std::array<std::size_t, 2> dilation = {{1, 1}};
+    std::vector<std::size_t> padding  = {0, 0};
+    std::vector<std::size_t> stride   = {1, 1};
+    std::vector<std::size_t> dilation = {1, 1};
     int group = 1;
     padding_mode_t padding_mode = default_;
@@ -39,32 +39,33 @@ struct convolution
     std::string name() const { return "convolution"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
+        check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
+        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        {
+            MIGRAPHX_THROW("convolution: inconsistent attribute sizes");
+        }
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
         auto t = input.type();
+        size_t kdims = input.lens().size() - 2;
         if(input.lens().at(1) != (weights.lens().at(1) * group))
             MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
-        return {t,
-                {
-                    input.lens()[0],
-                    weights.lens()[0],
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
-                         2 * padding[0]) /
-                                stride[0] +
-                            1)),
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
-                         2 * padding[1]) /
-                                stride[1] +
-                            1)),
-                }};
+        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+        for(size_t i = 0; i < kdims; i++)
+        {
+            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                1,
+                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
+                 2 * padding[i]) /
+                        stride[i] +
+                    1)));
+        }
+        return {t, output_lens};
     }
 };
......
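The loop above generalizes the old hard-coded H/W arithmetic: each spatial output dim is max(1, (in - (1 + dilation*(k - 1)) + 2*pad) / stride + 1). A small worked check for a 1-D convolution, with illustrative shapes:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // input {N=1, C=3, W=5}, weights {O=4, C=3, K=3}
    std::vector<std::size_t> input   = {1, 3, 5};
    std::vector<std::size_t> weights = {4, 3, 3};
    std::vector<std::size_t> padding = {1}, stride = {1}, dilation = {1};

    std::size_t kdims = input.size() - 2; // one spatial dim
    std::vector<std::size_t> output_lens{input[0], weights[0]};
    for(std::size_t i = 0; i < kdims; i++)
    {
        output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input[i + 2] - (1 + dilation[i] * (weights[i + 2] - 1)) + 2 * padding[i]) /
                    stride[i] +
                1)));
    }
    // (5 - (1 + 1*(3 - 1)) + 2*1) / 1 + 1 = 5: a 3-wide kernel with padding 1
    // preserves the spatial extent, so the output shape is {1, 4, 5}.
    for(auto d : output_lens)
        std::cout << d << ' ';
}
```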
@@ -18,9 +18,9 @@ namespace op {
 struct im2col
 {
-    std::array<std::size_t, 2> padding  = {{0, 0}};
-    std::array<std::size_t, 2> stride   = {{1, 1}};
-    std::array<std::size_t, 2> dilation = {{1, 1}};
+    std::vector<std::size_t> padding{0, 0};
+    std::vector<std::size_t> stride{1, 1};
+    std::vector<std::size_t> dilation{1, 1};
     padding_mode_t padding_mode = default_;
......
@@ -51,6 +51,12 @@ struct pad
         return s;
     }
+    std::size_t pad_ndims() const
+    {
+        assert(pads.size() % 2 == 0);
+        return pads.size() / 2;
+    }
     bool symmetric() const
     {
         std::size_t num_dims = pads.size() / 2;
......
@@ -20,11 +20,11 @@ namespace op {
 struct pooling
 {
-    std::string mode = "average";
-    std::array<std::size_t, 2> padding = {{0, 0}};
-    std::array<std::size_t, 2> stride  = {{1, 1}};
-    std::array<std::size_t, 2> lengths = {{1, 1}};
-    padding_mode_t padding_mode = default_;
+    std::string mode                 = "average";
+    std::vector<std::size_t> padding = {0, 0};
+    std::vector<std::size_t> stride  = {1, 1};
+    std::vector<std::size_t> lengths = {1, 1};
+    padding_mode_t padding_mode      = default_;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -40,29 +40,31 @@ struct pooling
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).only_dims(4);
+        check_shapes{inputs, *this}.has(1);
+        if(not(padding.size() == stride.size() and padding.size() == lengths.size()))
+        {
+            MIGRAPHX_THROW("pooling: inconsistent attribute sizes");
+        }
         const shape& input = inputs.at(0);
         auto t             = input.type();
-        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
-        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
+        auto input_lens = input.lens();
+        size_t kdims    = input_lens.size() - 2;
-        return {t,
-                {
-                    input.lens()[0],
-                    input.lens()[1],
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
-                                                     stride[0]) +
-                            1)),
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
-                                                     stride[1]) +
-                            1)),
-                }};
+        std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
+        for(size_t i = 0; i < kdims; i++)
+        {
+            assert(lengths[i] <= input_lens[i + 2] + 2 * padding[i]);
+            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                1,
+                floor_divide<std::ptrdiff_t>(input_lens[i + 2] + 2 * padding[i] - lengths[i],
+                                             stride[i]) +
+                    1)));
+        }
+        return {t, output_lens};
     }
 };
......
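Pooling differs from convolution in using floor division and the kernel lengths directly: each spatial output dim is max(1, floor((in + 2*pad - len) / stride) + 1), and the per-dim size assert now lives inside the loop. A worked 1-D check with illustrative values; the floor_divide below is a stand-in with assumed round-toward-negative-infinity semantics, not migraphx's actual helper:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for migraphx's floor_divide (assumed semantics: round toward
// negative infinity; the operands here are non-negative anyway).
std::ptrdiff_t floor_divide(std::ptrdiff_t a, std::ptrdiff_t b)
{
    return (a - (a % b + b) % b) / b;
}

int main()
{
    std::vector<std::size_t> input_lens = {1, 3, 7}; // {N, C, W}
    std::vector<std::size_t> padding = {0}, stride = {2}, lengths = {3};

    std::size_t kdims = input_lens.size() - 2;
    std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
    for(std::size_t i = 0; i < kdims; i++)
    {
        output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
            1, floor_divide(input_lens[i + 2] + 2 * padding[i] - lengths[i], stride[i]) + 1)));
    }
    // floor((7 + 0 - 3) / 2) + 1 = 3, so the output shape is {1, 3, 3}.
    for(auto d : output_lens)
        std::cout << d << ' ';
}
```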
@@ -19,9 +19,9 @@ namespace op {
 struct quant_convolution
 {
-    std::array<std::size_t, 2> padding  = {{0, 0}};
-    std::array<std::size_t, 2> stride   = {{1, 1}};
-    std::array<std::size_t, 2> dilation = {{1, 1}};
+    std::vector<std::size_t> padding  = {0, 0};
+    std::vector<std::size_t> stride   = {1, 1};
+    std::vector<std::size_t> dilation = {1, 1};
     padding_mode_t padding_mode = default_;
     int group = 1;
@@ -39,11 +39,16 @@ struct quant_convolution
     std::string name() const { return "quant_convolution"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
+        check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
+        if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
+        {
+            MIGRAPHX_THROW("quant_convolution: inconsistent attribute sizes");
+        }
         const shape& input   = inputs.at(0);
         const shape& weights = inputs.at(1);
         auto t = input.type();
+        size_t kdims = input.lens().size() - 2;
         // all inputs must be int8_type; the output is int32_type
         if(t != shape::int8_type)
@@ -52,23 +57,19 @@
         }
         t = shape::int32_type;
-        return {t,
-                {
-                    input.lens()[0],
-                    weights.lens()[0],
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
-                         2 * padding[0]) /
-                                stride[0] +
-                            1)),
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
-                         2 * padding[1]) /
-                                stride[1] +
-                            1)),
-                }};
+        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+        for(size_t i = 0; i < kdims; i++)
+        {
+            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
+                1,
+                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
+                 2 * padding[i]) /
+                        stride[i] +
+                    1)));
+        }
+        return {t, output_lens};
     }
 };
......
@@ -20,16 +20,17 @@ inline void calculate_padding(int64_t idx,
     int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
     int64_t pad =
         std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
+    auto pad_ndims = pads.size() / 2;
     if(is_same_upper)
     {
-        pads[idx]     = pad / 2;
-        pads[idx + 2] = pad - pad / 2;
+        pads[idx]             = pad / 2;
+        pads[idx + pad_ndims] = pad - pad / 2;
     }
     else
     {
-        pads[idx + 2] = pad / 2;
-        pads[idx]     = pad - pad / 2;
+        pads[idx + pad_ndims] = pad / 2;
+        pads[idx]             = pad - pad / 2;
     }
 }
......
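calculate_padding splits the total SAME padding so the output keeps ceil(input/stride) elements, with any odd remainder going to the trailing side for SAME_UPPER and the leading side for SAME_LOWER. A quick arithmetic check with made-up values (dilation 1, so the effective kernel width equals the raw kernel width):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main()
{
    // output_dim = ceil(input_dim / stride); total pad needed to reach it:
    std::int64_t input_dim = 5, stride = 2, kernel = 3;
    std::int64_t output_dim = (input_dim + stride - 1) / stride; // 3
    std::int64_t pad =
        std::max<std::int64_t>(0, (output_dim - 1) * stride + kernel - input_dim); // 2

    // SAME_UPPER puts the smaller half first (leading), the larger half last.
    assert(pad / 2 == 1 && pad - pad / 2 == 1); // even total: 1 before, 1 after

    std::int64_t odd_total = 3; // an odd total pad
    assert(odd_total / 2 == 1 && odd_total - odd_total / 2 == 2); // 1 before, 2 after
}
```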
@@ -323,16 +323,34 @@ struct onnx_parser
                             Op& op,
                             float pad_val = 0)
     {
-        if(padding[0] != padding[2] || padding[1] != padding[3])
+        bool asym_padding = false;
+        assert(padding.size() % 2 == 0);
+        size_t pad_ndims  = padding.size() / 2;
+        auto left_pad_it  = padding.begin();
+        auto right_pad_it = left_pad_it + pad_ndims;
+        for(size_t i = 0; i < pad_ndims; i++)
+        {
+            if(padding[i] != padding[i + pad_ndims])
+            {
+                asym_padding = true;
+                break;
+            }
+        }
+        if(asym_padding)
         {
-            ins = prog.add_instruction(
-                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
-                ins);
+            std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
+            // add left pads
+            asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
+            // add right pads
+            asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
+            ins = prog.add_instruction(op::pad{asym_pads, pad_val}, ins);
         }
         else
         {
-            op.padding[0] = padding[0];
-            op.padding[1] = padding[1];
+            op.padding = std::vector<size_t>(left_pad_it, right_pad_it);
         }
     }
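The pad operator wants all leading pads ({N, C, spatial...}) followed by all trailing pads, while the ONNX attribute supplies spatial leading pads then spatial trailing pads, so the two inserts above splice zeros in for N and C. A tiny stand-alone check of the assembly (values are made up):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // ONNX-style pads for 2 spatial dims: {top, left, bottom, right}
    std::vector<std::int64_t> padding = {1, 2, 3, 4};
    std::size_t pad_ndims = padding.size() / 2;
    auto left_pad_it  = padding.begin();
    auto right_pad_it = left_pad_it + pad_ndims;

    std::vector<std::int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
    // leading spatial pads go after the two leading zeros...
    asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
    // ...and trailing spatial pads go at the very end
    asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
    assert(asym_pads == (std::vector<std::int64_t>{0, 0, 1, 2, 0, 0, 3, 4}));
}
```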
@@ -427,11 +445,14 @@ struct onnx_parser
     instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                                node_info info,
                                                Op& op,
-                                               std::array<std::size_t, 2> k_lens,
-                                               std::array<std::size_t, 2> dilation,
+                                               std::vector<std::size_t> k_lens,
+                                               std::vector<std::size_t> dilation,
                                                const std::vector<std::size_t>& in_lens,
                                                float value = 0.0f)
     {
+        size_t kdims = in_lens.size() - 2;
+        assert(k_lens.size() == kdims and dilation.size() == kdims);
         if(!contains(info.attributes, "auto_pad"))
         {
             return ins;
@@ -440,12 +461,20 @@
         auto auto_pad = info.attributes["auto_pad"].s();
         if(auto_pad.find("SAME") != std::string::npos)
         {
             op.padding_mode    = op::padding_mode_t::same;
             bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
-            std::vector<int64_t> padding(in_lens.size());
-            calculate_padding(
-                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
-            calculate_padding(
-                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);
+            std::vector<int64_t> padding(2 * kdims);
+            for(size_t i = 0; i < padding.size() / 2; i++)
+            {
+                calculate_padding(i,
+                                  padding,
+                                  in_lens[i + 2],
+                                  op.stride[i],
+                                  dilation[i],
+                                  k_lens[i],
+                                  is_same_upper);
+            }
             check_asym_padding(ins, padding, op, value);
         }
@@ -529,6 +558,35 @@ struct onnx_parser
         return input;
     }
+    void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg)
+    {
+        if(kdims != attr_size)
+        {
+            MIGRAPHX_THROW(error_msg + " k-dims: " + to_string(kdims) +
+                           " attribute size: " + to_string(attr_size));
+        }
+    }
+    template <class Op>
+    void recalc_conv_attributes(Op& op, size_t kdims)
+    {
+        if(op.padding.size() != kdims)
+        {
+            op.padding.resize(kdims);
+            std::fill_n(op.padding.begin(), kdims, 0);
+        }
+        if(op.stride.size() != kdims)
+        {
+            op.stride.resize(kdims);
+            std::fill_n(op.stride.begin(), kdims, 1);
+        }
+        if(op.dilation.size() != kdims)
+        {
+            op.dilation.resize(kdims);
+            std::fill_n(op.dilation.begin(), kdims, 1);
+        }
+    }
     template <class Op>
     instruction_ref
     parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
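recalc_conv_attributes backfills any attribute still at its 2-D default so its length matches the actual number of spatial dims before compute_shape's consistency check runs. A minimal sketch of the effect on a hypothetical conv-like op (conv_like is illustrative, not a migraphx type); the real helper uses resize plus fill_n, which this assign-based version matches only when the sizes disagree:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in op carrying the attributes this PR converts from arrays to vectors.
struct conv_like
{
    std::vector<std::size_t> padding  = {0, 0};
    std::vector<std::size_t> stride   = {1, 1};
    std::vector<std::size_t> dilation = {1, 1};
};

int main()
{
    conv_like op;          // defaults sized for 2 spatial dims
    std::size_t kdims = 3; // e.g. a Conv3D input (5-D tensor)
    // Same resize-and-fill the parser applies when a size disagrees with kdims:
    op.padding.assign(kdims, 0);
    op.stride.assign(kdims, 1);
    op.dilation.assign(kdims, 1);
    assert(op.padding.size() == 3 && op.stride == (std::vector<std::size_t>{1, 1, 1}));
}
```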
@@ -536,6 +594,10 @@ struct onnx_parser
         Op op;
         auto l0      = args[0];
         auto weights = args[1];
+        auto in_lens = l0->get_shape().lens();
+        assert(in_lens.size() > 2);
+        auto kdims = in_lens.size() - 2;
+        std::vector<int64_t> padding;
         if(contains(info.attributes, "pads"))
         {
@@ -548,44 +610,28 @@
                         "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                 }
             }
+            op.padding.clear();
             copy(info.attributes["pads"].ints(), std::back_inserter(padding));
-            if(padding.size() != 4)
-            {
-                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
-            }
+            check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
             check_asym_padding(l0, padding, op);
         }
         if(contains(info.attributes, "strides"))
         {
-            copy(info.attributes["strides"].ints(), op.stride.begin());
+            op.stride.clear();
+            copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
+            check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV: inconsistent strides");
         }
         if(contains(info.attributes, "dilations"))
         {
-            copy(info.attributes["dilations"].ints(), op.dilation.begin());
+            op.dilation.clear();
+            copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
+            check_attr_sizes(kdims, op.dilation.size(), "PARSE_CONV: inconsistent dilations");
         }
         if(contains(info.attributes, "auto_pad"))
         {
-            auto s = info.attributes["auto_pad"].s();
-            if(s.find("SAME") != std::string::npos)
-            {
-                op.padding_mode = op::padding_mode_t::same;
-                std::vector<size_t> weight_dims = weights->get_shape().lens();
-                size_t weight_h = weight_dims[2];
-                size_t weight_w = weight_dims[3];
-                auto input_dims = l0->get_shape().lens();
-                padding.resize(input_dims.size());
-                calculate_padding(
-                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
-                calculate_padding(
-                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);
-                check_asym_padding(l0, padding, op);
-            }
+            auto weight_lens = weights->get_shape().lens();
-            auto in_lens     = args[0]->get_shape().lens();
-            auto weight_lens = args[1]->get_shape().lens();
-            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
+            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
             l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
         }
         if(contains(info.attributes, "group"))
@@ -593,6 +639,8 @@
             op.group = parse_value(info.attributes.at("group")).at<int>();
         }
+        recalc_conv_attributes(op, kdims);
+
         auto l1 = prog.add_instruction(op, l0, args[1]);
         return add_bias(args, l1, 1);
     }
@@ -705,11 +753,14 @@ struct onnx_parser
     parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
     {
         op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
-        auto l0 = args[0];
+        auto l0      = args[0];
+        auto in_lens = l0->get_shape().lens();
+        assert(in_lens.size() > 2);
+        auto kdims = in_lens.size() - 2;
         if(starts_with(name, "Global"))
         {
-            auto lens  = args.front()->get_shape().lens();
-            op.lengths = {lens[2], lens[3]};
+            op.lengths = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
         }
         if(contains(info.attributes, "pads"))
@@ -723,45 +774,60 @@
                         "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                 }
             }
-            std::vector<std::int64_t> padding;
+            op.padding.clear();
+            std::vector<int64_t> padding;
             copy(info.attributes["pads"].ints(), std::back_inserter(padding));
-            if(padding.size() != 4)
-            {
-                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
-            }
+            check_attr_sizes(kdims, padding.size() / 2, "PARSE_POOLING: inconsistent paddings");
             float pad_val = 0;
             if(op.mode == "max")
                 pad_val = std::numeric_limits<float>::lowest();
             check_asym_padding(l0, padding, op, pad_val);
+            in_lens = l0->get_shape().lens();
         }
         if(contains(info.attributes, "strides"))
         {
-            copy(info.attributes["strides"].ints(), op.stride.begin());
+            op.stride.clear();
+            copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
+            check_attr_sizes(kdims, op.stride.size(), "PARSE_POOLING: inconsistent strides");
         }
         if(contains(info.attributes, "kernel_shape"))
         {
-            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
+            op.lengths.clear();
+            copy(info.attributes["kernel_shape"].ints(), std::back_inserter(op.lengths));
+            check_attr_sizes(kdims, op.lengths.size(), "PARSE_POOLING: inconsistent lengths");
         }
         if(contains(info.attributes, "auto_pad"))
         {
             auto s = info.attributes["auto_pad"].s();
             if(s.find("SAME") != std::string::npos)
             {
                 op.padding_mode = op::padding_mode_t::same;
             }
-            auto in_lens = args[0]->get_shape().lens();
-            float val = 0.0f;
+            op.padding.clear();
+            float val = 0.0f;
             // MaxPool
             if(op.mode == "max")
             {
                 val = std::numeric_limits<float>::lowest();
             }
-            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
+            l0      = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
+            in_lens = l0->get_shape().lens();
         }
+        if(op.padding.size() != kdims)
+        {
+            op.padding.resize(kdims);
+            std::fill_n(op.padding.begin(), kdims, 0);
+        }
+        if(op.stride.size() != kdims)
+        {
+            op.stride.resize(kdims);
+            std::fill_n(op.stride.begin(), kdims, 1);
+        }
+        for(size_t i = 0; i < kdims; i++)
+        {
+            if(op.lengths[i] > in_lens[i + 2] + 2 * op.padding[i])
+                MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
+        }
         return prog.add_instruction(op, l0);
......
@@ -163,10 +163,10 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
     auto op = conv.op;
     // Don't fuse Winograd for non-3x3s since there is no fused Winograd for those configs
     if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and
-       wei.lens()[3] != 3 and op.stride == make_array<size_t>(1, 1))
+       wei.lens()[3] != 3 and contains({{1, 1}}, op.stride))
         return false;
     return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and
-           contains({{0, 0}, {1, 1}}, op.stride) and op.dilation == make_array<size_t>(1, 1);
+           contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation);
 }
 struct hip_triadd
@@ -600,7 +600,7 @@ struct miopen_conv_bias
     }
     miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
-        : op(c), f(input)
+        : op(std::move(c)), f(input)
     {
         conv = f.create_conv(op, weights);
         bias = f.create_bias(b);
@@ -649,7 +649,7 @@ struct miopen_conv_bias_relu
                          const shape& input,
                          const shape& weights,
                          const shape& b)
-        : op(c), f(input)
+        : op(std::move(c)), f(input)
     {
         conv = f.create_conv(op, weights);
         bias = f.create_bias(b);
......
@@ -7,7 +7,7 @@ namespace gpu {
 shape miopen_pooling::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).standard();
+    check_shapes{inputs, *this}.has(2).standard().only_dims(4);
     return op.compute_shape({inputs.at(0)});
 }
 argument miopen_pooling::compute(context& ctx,
......
@@ -408,9 +408,9 @@ TEST_CASE(im2col_3x3_no_pad_identity_test)
 {
     std::size_t f[2]    = {3, 3};
     std::size_t size[2] = {3, 3};
-    std::array<std::size_t, 2> padding{{0, 0}};
-    std::array<std::size_t, 2> stride{{1, 1}};
-    std::array<std::size_t, 2> dilation{{1, 1}};
+    std::vector<std::size_t> padding{0, 0};
+    std::vector<std::size_t> stride{1, 1};
+    std::vector<std::size_t> dilation{1, 1};
     std::size_t channels = 1;
     std::vector<int32_t> weights(channels * f[0] * f[1]);
@@ -437,9 +437,9 @@ TEST_CASE(im2col_3x3_no_pad_test)
 {
     std::size_t f[2]    = {3, 3};
     std::size_t size[2] = {4, 4};
-    std::array<std::size_t, 2> padding{{0, 0}};
-    std::array<std::size_t, 2> stride{{1, 1}};
-    std::array<std::size_t, 2> dilation{{1, 1}};
+    std::vector<std::size_t> padding{0, 0};
+    std::vector<std::size_t> stride{1, 1};
+    std::vector<std::size_t> dilation{1, 1};
     std::size_t channels = 1;
     std::vector<int32_t> weights(channels * f[0] * f[1]);
@@ -469,9 +469,9 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test)
 {
     std::size_t f[2]    = {3, 3};
     std::size_t size[2] = {6, 6};
-    std::array<std::size_t, 2> padding{{0, 0}};
-    std::array<std::size_t, 2> stride{{2, 2}};
-    std::array<std::size_t, 2> dilation{{1, 1}};
+    std::vector<std::size_t> padding{0, 0};
+    std::vector<std::size_t> stride{2, 2};
+    std::vector<std::size_t> dilation{1, 1};
     std::size_t channels = 1;
     std::vector<int32_t> weights(channels * f[0] * f[1]);
@@ -502,9 +502,9 @@ TEST_CASE(im2col_3x3_with_padding_test)
 {
     std::size_t f[2]    = {3, 3};
     std::size_t size[2] = {2, 2};
-    std::array<std::size_t, 2> padding{{1, 1}};
-    std::array<std::size_t, 2> stride{{1, 1}};
-    std::array<std::size_t, 2> dilation{{1, 1}};
+    std::vector<std::size_t> padding{1, 1};
+    std::vector<std::size_t> stride{1, 1};
+    std::vector<std::size_t> dilation{1, 1};
     std::size_t channels = 1;
     std::vector<int32_t> weights(channels * f[0] * f[1]);
@@ -580,9 +580,9 @@ TEST_CASE(im2col_3x3_with_channels_identity_test)
 {
     std::size_t f[2]    = {3, 3};
     std::size_t size[2] = {3, 3};
-    std::array<std::size_t, 2> padding{{0, 0}};
-    std::array<std::size_t, 2> stride{{1, 1}};
-    std::array<std::size_t, 2> dilation{{1, 1}};
+    std::vector<std::size_t> padding{0, 0};
+    std::vector<std::size_t> stride{1, 1};
+    std::vector<std::size_t> dilation{1, 1};
    std::size_t channels = 2;
    std::vector<int32_t> weights(channels * f[0] * f[1]);
......
[New binary ONNX test models added (protobuf payloads not rendered): averagepool_1d_test, averagepool_3d_test, batchnorm_1d_test, batchnorm_3d_test, conv_1d_test, conv_3d_test, and conv_attr_fail_test, the inconsistent-attribute case referenced in the commit log.]