Commit 7e23d5c4 authored by Scott Thornton's avatar Scott Thornton
Browse files

Deleted memcpy for Paul. Added fix in Reshape for cases with -1 dimension.

parent 8c2d316e
......@@ -341,11 +341,25 @@ struct reshape
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
}
if(n_neg_dims > 0)
{
size_t missing_dim =
-inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
if(dims.back() == -1)
{
rdims.pop_back();
......
......@@ -320,10 +320,7 @@ struct onnx_parser
std::string s = t.raw_data();
if(t.data_type() == onnx::TensorProto::FLOAT)
{
std::vector<float> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::float_type, dims}, raw};
return literal{{shape::float_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT8)
{
......@@ -331,38 +328,23 @@ struct onnx_parser
}
else if(t.data_type() == onnx::TensorProto::INT8)
{
std::vector<int32_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int32_type, dims}, raw};
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT16)
{
std::vector<int32_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int32_type, dims}, raw};
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT16)
{
std::vector<int32_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int32_type, dims}, raw};
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT32)
{
std::vector<int32_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int32_type, dims}, raw};
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT64)
{
std::vector<int64_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int64_type, dims}, raw};
return literal{{shape::int64_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::STRING)
{
......@@ -370,10 +352,7 @@ struct onnx_parser
}
else if(t.data_type() == onnx::TensorProto::BOOL)
{
std::vector<int32_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int32_type, dims}, raw};
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::FLOAT16)
{
......@@ -381,10 +360,7 @@ struct onnx_parser
}
else if(t.data_type() == onnx::TensorProto::DOUBLE)
{
std::vector<double> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::double_type, dims}, raw};
return literal{{shape::double_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT32)
{
......
Markdown is supported
0% uploaded — or drag and drop files to attach.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.