Commit 60a0f286 authored by Shucai Xiao

Changes to be able to parse a simple PyTorch model

parent 2629e8f1
@@ -59,8 +59,8 @@ struct broadcast
         std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
         shape output{t, broadcast_lens, std::move(bcast_strides)};
-        if(output.elements() < input.elements())
-            MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
+        // if(output.elements() < input.elements())
+        //     MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
         return output;
     }

     argument compute(shape output_shape, std::vector<argument> args) const

@@ -55,7 +55,8 @@ struct convolution
         check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
         check_attribute_size();
         // dim num of input and attribute should match
-        auto input_size = inputs[0].lens().size();
+        auto in_lens    = inputs[0].lens();
+        auto input_size = in_lens.size();
         auto padding_size = padding.size();
         if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
         {
@@ -73,7 +74,7 @@ struct convolution
         if(input.lens().at(1) != (weights.lens().at(1) * group))
             MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
-        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
+        std::vector<size_t> output_lens{in_lens[0], weights.lens()[0]};
         for(size_t i = 0; i < kdims; i++)
         {
@@ -82,7 +83,7 @@ struct convolution
                 padding_factor = padding[i] + padding[i + kdims];
             output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
                 1,
-                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
+                (in_lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
                  padding_factor) /
                         stride[i] +
                     1)));

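Note on the hunk above: the expression being refactored is the standard convolution output-size formula, out = max(1, (in - (1 + dilation * (k - 1)) + pad_begin + pad_end) / stride + 1). A minimal standalone sketch of that arithmetic (hypothetical helper, not part of this commit):

#include <algorithm>
#include <cstddef>
#include <iostream>

// Mirrors the expression in the hunk above; padding_factor = pad_begin + pad_end.
std::ptrdiff_t conv_out_len(std::ptrdiff_t in, std::ptrdiff_t k, std::ptrdiff_t stride,
                            std::ptrdiff_t dilation, std::ptrdiff_t padding_factor)
{
    return std::max<std::ptrdiff_t>(
        1, (in - (1 + dilation * (k - 1)) + padding_factor) / stride + 1);
}

int main()
{
    // 224-wide input, 7-wide kernel, stride 2, dilation 1, padding 3 on each side
    std::cout << conv_out_len(224, 7, 2, 1, 6) << std::endl; // prints 112
}
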
@@ -11,6 +11,7 @@
 #include <migraphx/lifetime.hpp>
 #include <cmath>
 #include <utility>
+#include <iostream>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -30,6 +31,13 @@ struct reshape
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(1).standard();
+        // input shape is dynamic, return dims directly
+        if(inputs.front().dynamic())
+        {
+            std::vector<std::size_t> rdims(dims.begin(), dims.end());
+            return {inputs.front().type(), rdims};
+        }
         auto&& idims = inputs.front().lens();
         std::vector<std::size_t> rdims(dims.begin(), dims.end());
         auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);

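The early-out added above skips the usual -1 inference when the input shape is dynamic. A standalone sketch of the idea (plain function, not the MIGraphX reshape class):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// When any input length is 0 (the "dynamic" marker used throughout this
// commit), return the requested dims verbatim instead of resolving -1
// entries against the input's element count.
std::vector<std::size_t> reshape_out_dims(const std::vector<std::size_t>& in_lens,
                                          const std::vector<std::int64_t>& dims)
{
    bool dynamic =
        std::find(in_lens.begin(), in_lens.end(), std::size_t{0}) != in_lens.end();
    if(dynamic)
        return {dims.begin(), dims.end()};
    // static path (elided): infer any -1 dim from in_lens' element count
    return {dims.begin(), dims.end()};
}
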
@@ -20,7 +20,7 @@ struct shape_op
         return {shape::int64_type, lens};
     }

-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto lens = args.front().get_shape().lens();

@@ -125,6 +125,8 @@ struct shape
     bool standard() const;
     /// Returns true if all strides are equal to 0 (scalar tensor)
     bool scalar() const;
+    /// Returns true if any dim is 0
+    bool dynamic() const;

     shape normalize_standard() const;

@@ -280,7 +280,7 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
 bool instruction::can_eval() const
 {
-    if(op.name() == "@literal")
+    if(op.name() == "@literal" or op.name() == "shape")
     {
         return true;
     }

@@ -301,10 +301,18 @@ argument instruction::eval(bool check_eval) const
     {
         return this->get_literal().get_argument();
     }
+    else if(op.name() == "shape")
+    {
+        argument arg{this->inputs().front()->get_shape()};
+        return normalized_operator().compute(result, {arg});
+    }
     if(is_context_free(op))
     {
         if(check_eval and not this->can_eval())
+        {
             return {};
+        }
         std::vector<argument> args;
         std::transform(this->inputs().begin(),
                        this->inputs().end(),

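Context for the two instruction.cpp hunks: a "shape" instruction's output depends only on its input's dimensions, so once shape_op::compute no longer takes a context (see the earlier hunk), it can be folded during eval without running the graph. A minimal sketch of that folding idea (not the MIGraphX API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// A "shape" node just materializes its input's lengths as int64 data,
// so it can be evaluated at parse time.
std::vector<std::int64_t> fold_shape_node(const std::vector<std::size_t>& input_lens)
{
    return {input_lens.begin(), input_lens.end()};
}

int main()
{
    for(auto d : fold_shape_node({1, 3, 224, 224}))
        std::cout << d << ' '; // prints 1 3 224 224
}
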
@@ -248,10 +248,10 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
         }
         std::vector<std::size_t> dims;
-        if(map_input_dims.count(name) > 0)
-        {
-            dims = map_input_dims.at(name);
-        }
+        // if(map_input_dims.count(name) > 0)
+        // {
+        //     dims = map_input_dims.at(name);
+        // }
         shape s = parse_type(input.type(), dims);
         mod_insts[name] = mod->add_parameter(name, s);
@@ -262,6 +262,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
     for(auto&& node : graph.node())
     {
+        std::cout << "node_op_type = " << node.op_type() << std::endl;
         std::vector<instruction_ref> args;
         for(auto&& input : node.input())
         {

@@ -404,10 +405,10 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
                               const std::vector<std::size_t>& input_dims) const
 {
     shape::type_t shape_type = get_type(t.tensor_type().elem_type());
-    if(!input_dims.empty())
-    {
-        return {shape_type, input_dims};
-    }
+    // if(!input_dims.empty())
+    // {
+    //     return {shape_type, input_dims};
+    // }
     std::vector<std::size_t> dims;
     auto&& tensor_dims = t.tensor_type().shape().dim();
@@ -419,13 +420,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
     {
         if(static_cast<int>(d.dim_value()) <= 0)
         {
-            return default_dim_value;
+            // return default_dim_value;
+            return 0;
         }
         return d.dim_value();
     }
     else
     {
-        return default_dim_value;
+        // return default_dim_value;
+        return 0;
     }
 });

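The parse_type hunks change the convention for unknown ONNX dimensions: instead of substituting default_dim_value, any symbolic or non-positive dimension now becomes 0, which the new shape::dynamic() predicate detects downstream. A sketch of the mapping (hypothetical dim struct standing in for the ONNX proto type):

#include <cstddef>

struct dim
{
    bool has_value;
    long long value;
};

// Unknown or non-positive dims map to 0, i.e. "dynamic".
std::size_t to_len(const dim& d)
{
    if(d.has_value and d.value > 0)
        return static_cast<std::size_t>(d.value);
    return 0;
}
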
@@ -272,6 +272,13 @@ bool shape::scalar() const
            std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
 }

+bool shape::dynamic() const
+{
+    if(scalar())
+        return false;
+    const auto& lens = this->lens();
+    return std::find(lens.begin(), lens.end(), 0) != lens.end();
+}
+
 bool shape::standard() const { return impl->m_standard; }

 shape shape::normalize_standard() const

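A quick illustration of the new predicate, assuming a build of this branch where shape::dynamic() exists:

#include <migraphx/shape.hpp>
#include <iostream>

int main()
{
    using migraphx::shape;
    shape fixed{shape::float_type, {1, 3, 224, 224}};
    shape dyn{shape::float_type, {0, 3, 224, 224}}; // 0 marks an unknown dim
    std::cout << fixed.dynamic() << std::endl; // 0
    std::cout << dyn.dynamic() << std::endl;   // 1
}
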
@@ -5,7 +5,7 @@
 namespace migraphx {

-using index_int = std::uint32_t;
+using index_int = std::int32_t;

 #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT

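One observable effect of the index_int change above: std::int32_t halves the maximum index but lets negative intermediate values behave arithmetically instead of wrapping. The motivation is not stated in the commit; this sketch only shows the wrap-around difference:

#include <cstdint>
#include <iostream>

int main()
{
    std::uint32_t u = 0;
    std::int32_t s  = 0;
    --u;
    --s;
    std::cout << u << ' ' << s << std::endl; // prints 4294967295 -1
}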