Commit 60a0f286 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

changes to be able to parse the simple pytorch model

parent 2629e8f1
......@@ -59,8 +59,8 @@ struct broadcast
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
// if(output.elements() < input.elements())
// MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -55,7 +55,8 @@ struct convolution
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
// dim num of input and attribute should match
auto input_size = inputs[0].lens().size();
auto in_lens = inputs[0].lens();
auto input_size = in_lens.size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
......@@ -73,7 +74,7 @@ struct convolution
if(input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
std::vector<size_t> output_lens{in_lens[0], weights.lens()[0]};
for(size_t i = 0; i < kdims; i++)
{
......@@ -82,7 +83,7 @@ struct convolution
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
(in_lens[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) /
stride[i] +
1)));
......
......@@ -11,6 +11,7 @@
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -30,6 +31,13 @@ struct reshape
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
// input shape is dynamic, return dim directly
if (inputs.front().dynamic())
{
std::vector<std::size_t> rdims(dims.begin(), dims.end());
return {inputs.front().type(), rdims};
}
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
......
......@@ -20,7 +20,7 @@ struct shape_op
return {shape::int64_type, lens};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto lens = args.front().get_shape().lens();
......
......@@ -125,6 +125,8 @@ struct shape
bool standard() const;
/// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const;
/// Return true if any dim is 0
bool dynamic() const;
shape normalize_standard() const;
......
......@@ -280,7 +280,7 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
bool instruction::can_eval() const
{
if(op.name() == "@literal")
if(op.name() == "@literal" or op.name() == "shape")
{
return true;
}
......@@ -301,10 +301,18 @@ argument instruction::eval(bool check_eval) const
{
return this->get_literal().get_argument();
}
else if (op.name() == "shape")
{
argument arg{this->inputs().front()->get_shape()};
return normalized_operator().compute(result, {arg});
}
if(is_context_free(op))
{
if(check_eval and not this->can_eval())
{
return {};
}
std::vector<argument> args;
std::transform(this->inputs().begin(),
this->inputs().end(),
......
......@@ -248,10 +248,10 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
}
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
// if(map_input_dims.count(name) > 0)
// {
// dims = map_input_dims.at(name);
// }
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s);
......@@ -262,6 +262,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
for(auto&& node : graph.node())
{
std::cout << "node_op_type = " << node.op_type() << std::endl;
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
......@@ -404,10 +405,10 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
const std::vector<std::size_t>& input_dims) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(!input_dims.empty())
{
return {shape_type, input_dims};
}
// if(!input_dims.empty())
// {
// return {shape_type, input_dims};
// }
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
......@@ -419,13 +420,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
{
if(static_cast<int>(d.dim_value()) <= 0)
{
return default_dim_value;
// return default_dim_value;
return 0;
}
return d.dim_value();
}
else
{
return default_dim_value;
// return default_dim_value;
return 0;
}
});
......
......@@ -272,6 +272,13 @@ bool shape::scalar() const
std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
}
bool shape::dynamic() const
{
if (scalar()) return false;
const auto& lens = this->lens();
return std::find(lens.begin(), lens.end(), 0) != lens.end();
}
bool shape::standard() const { return impl->m_standard; }
shape shape::normalize_standard() const
......
......@@ -5,7 +5,7 @@
namespace migraphx {
using index_int = std::uint32_t;
using index_int = std::int32_t;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment