Commit ad25a036 authored by Khalique

initial progress

parent 66bd7da6
@@ -608,6 +608,36 @@ struct reshape
    int output_alias(const std::vector<shape>&) const { return 0; }
};
struct pad
{
    std::vector<int64_t> pads;
    float value = 0.0f;
    std::string mode = "constant";

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.mode, "mode"), f(self.pads, "pads"), f(self.value, "value"));
    }

    std::string name() const { return "pad"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(idims.begin(), idims.end());
        auto num_dims = rdims.size();
        // pads stores the begin offsets for every dimension followed by the
        // end offsets, so dimension i grows by pads[i] + pads[i + num_dims].
        for(std::size_t i = 0; i < num_dims; i++)
        {
            rdims[i] += pads[i] + pads[i + num_dims];
        }
        shape s{inputs.front().type(), rdims};
        return s;
    }
};
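For reference, a minimal standalone sketch of the shape arithmetic above (plain C++, no migraphx types; padded_dims is a hypothetical helper, not part of this change, assuming the ONNX pads layout of all begin offsets followed by all end offsets):

#include <cstddef>
#include <cstdint>
#include <vector>

// Mirror of pad::compute_shape's dimension arithmetic: axis i grows by
// its begin offset pads[i] plus its end offset pads[i + n].
std::vector<std::size_t> padded_dims(std::vector<std::size_t> dims,
                                     const std::vector<std::int64_t>& pads)
{
    auto n = dims.size();
    for(std::size_t i = 0; i < n; i++)
        dims[i] += pads[i] + pads[i + n];
    return dims;
}

// Example: padded_dims({2, 3}, {1, 0, 1, 2}) returns {4, 5}.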
struct dot
{
    float alpha = 1.0;
...
@@ -81,6 +81,7 @@ struct onnx_parser
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat); add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Pad", &onnx_parser::parse_pad);
} }
template <class F> template <class F>
@@ -525,6 +526,29 @@ struct onnx_parser
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }
    instruction_ref
    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(attributes, "pads"))
        {
            auto&& pad_vals = attributes["pads"].ints();
            pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }
        // Only the default "constant" mode is supported; reject "reflect" and "edge".
        if(contains(attributes, "mode"))
        {
            auto mode = attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
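A sketch of the mapping parse_pad performs; the node and numbers below are illustrative, not taken from the change:

// An ONNX node such as
//   Pad(mode = "constant", pads = [0, 2, 0, 0], value = 0.0)
// applied to a 3x2 input parses to the instruction
//   migraphx::op::pad{{0, 2, 0, 0}, 0.0f}
// pads = [0, 2, 0, 0] adds two elements before axis 1 and none elsewhere,
// so the output shape is 3x4.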
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
...
@@ -297,6 +297,33 @@ struct cpu_contiguous
    }
};
struct cpu_pad
{
    op::pad op;
    std::string name() const { return "cpu::pad"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        assert(output_shape.standard());
        argument result{output_shape};
        // First fill the entire output with the padding value.
        result.visit([&](auto output) { std::fill(output.begin(), output.end(), op.value); });
        // Then copy every input element to its shifted position. Iterating over
        // the input shape (not the output) keeps idx + begin-pad in bounds.
        visit_all(result, args[0])([&](auto output, auto input) {
            shape_for_each(input.get_shape(), [&](const auto& idx) {
                std::vector<std::size_t> new_idx(idx.size());
                std::transform(idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(),
                               [](auto i, auto j) { return i + j; });
                output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end());
            });
        });
        return result;
    }
};
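For illustration, a one-dimensional standalone mirror of the fill-then-copy strategy used in cpu_pad::compute (pad1d is a hypothetical helper, assuming constant-mode padding only):

#include <algorithm>
#include <cstddef>
#include <vector>

// Fill the output with the pad value, then copy the input into the window
// starting at the begin-pad offset: the same two phases as cpu_pad.
std::vector<float> pad1d(const std::vector<float>& in,
                         std::size_t pad_before, std::size_t pad_after,
                         float value)
{
    std::vector<float> out(pad_before + in.size() + pad_after, value);
    std::copy(in.begin(), in.end(), out.begin() + pad_before);
    return out;
}

// Example: pad1d({1, 2, 3}, 2, 1, 0.0f) returns {0, 0, 1, 2, 3, 0}.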
struct cpu_concat
{
    op::concat op;
@@ -650,6 +677,7 @@ struct cpu_apply
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["pad"] = extend_op<cpu_pad, op::pad>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>(); apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>(); apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>(); apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
......