Commit 7e23d5c4 authored by Scott Thornton

Deleted memcpy for Paul. Added fix in Reshape for cases with -1 dimension.

parent 8c2d316e
@@ -341,11 +341,25 @@ struct reshape
         check_shapes{inputs, *this}.has(1);
         auto&& idims = inputs.front().lens();
         std::vector<std::size_t> rdims(dims.begin(), dims.end());
+        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
+        if(n_neg_dims > 1)
+            MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
         for(std::size_t i = 0; i < dims.size(); i++)
         {
             if(dims[i] == 0)
                 rdims[i] = idims[i];
         }
+        if(n_neg_dims > 0)
+        {
+            size_t missing_dim =
+                -inputs.front().elements() /
+                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
+            for(std::size_t i = 0; i < rdims.size(); i++)
+            {
+                if(dims[i] == -1)
+                    rdims[i] = missing_dim;
+            }
+        }
         if(dims.back() == -1)
         {
             rdims.pop_back();
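The -1 convention added here matches NumPy-style reshape: at most one dimension may be given as -1 and is then inferred from the input's element count, while 0 means "keep the corresponding input dimension". Below is a minimal standalone sketch of that inference step using only the standard library; resolve_dims and its parameter names are illustrative, not MIGraph API.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Resolve a reshape spec that may contain 0 ("copy the input dimension")
// and at most one -1 ("infer this dimension from the element count").
std::vector<std::size_t> resolve_dims(const std::vector<std::size_t>& in_dims,
                                      const std::vector<std::int64_t>& spec)
{
    if(std::count(spec.begin(), spec.end(), -1) > 1)
        throw std::runtime_error("reshape: only one -1 dimension is allowed");

    const std::size_t total = std::accumulate(
        in_dims.begin(), in_dims.end(), std::size_t{1}, std::multiplies<>{});

    std::vector<std::size_t> out(spec.size());
    std::size_t known   = 1;           // product of all explicitly given dims
    std::size_t neg_pos = spec.size(); // position of the -1, if any
    for(std::size_t i = 0; i < spec.size(); i++)
    {
        if(spec[i] == -1)
        {
            neg_pos = i;
            continue;
        }
        out[i] = (spec[i] == 0) ? in_dims[i] : static_cast<std::size_t>(spec[i]);
        known *= out[i];
    }
    if(neg_pos != spec.size())
        out[neg_pos] = total / known; // e.g. input {2,3,4} with spec {2,-1} -> {2,12}
    return out;
}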
@@ -320,10 +320,7 @@ struct onnx_parser
         std::string s = t.raw_data();
         if(t.data_type() == onnx::TensorProto::FLOAT)
         {
-            std::vector<float> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::float_type, dims}, raw};
+            return literal{{shape::float_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::UINT8)
         {
@@ -331,38 +331,23 @@ struct onnx_parser
         }
         else if(t.data_type() == onnx::TensorProto::INT8)
         {
-            std::vector<int32_t> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::int32_type, dims}, raw};
+            return literal{{shape::int32_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::UINT16)
         {
-            std::vector<int32_t> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::int32_type, dims}, raw};
+            return literal{{shape::int32_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::INT16)
         {
-            std::vector<int32_t> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::int32_type, dims}, raw};
+            return literal{{shape::int32_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::INT32)
         {
-            std::vector<int32_t> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::int32_type, dims}, raw};
+            return literal{{shape::int32_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::INT64)
         {
-            std::vector<int64_t> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::int64_type, dims}, raw};
+            return literal{{shape::int64_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::STRING)
         {
@@ -370,10 +352,7 @@ struct onnx_parser
         }
         else if(t.data_type() == onnx::TensorProto::BOOL)
         {
-            std::vector<int32_t> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::int32_type, dims}, raw};
+            return literal{{shape::int32_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::FLOAT16)
         {
@@ -381,10 +360,7 @@ struct onnx_parser
         }
         else if(t.data_type() == onnx::TensorProto::DOUBLE)
         {
-            std::vector<double> raw(
-                std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
-            memcpy(raw.data(), s.data(), s.length());
-            return literal{{shape::double_type, dims}, raw};
+            return literal{{shape::double_type, dims}, s.data()};
         }
         else if(t.data_type() == onnx::TensorProto::UINT32)
         {
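For reference, each deleted block above followed the same pattern, sketched below with standard C++ only (bytes_to_floats is a hypothetical helper, not part of MIGraph): copy the tensor's raw_data() bytes into a typed std::vector, then construct the literal from that vector. The replacement passes s.data() to the literal constructor directly, on the assumption that the pointer-taking overload copies the bytes itself, which removes one intermediate allocation and one memcpy per initializer tensor.

#include <cstddef>
#include <cstring>
#include <string>
#include <vector>

// What the removed lines did for the FLOAT case, in isolation:
// reinterpret a raw byte buffer (e.g. onnx::TensorProto::raw_data())
// as a vector of float before handing it to the literal constructor.
std::vector<float> bytes_to_floats(const std::string& raw, std::size_t n_elements)
{
    std::vector<float> out(n_elements);
    // Assumes raw.size() == n_elements * sizeof(float); copies every byte.
    std::memcpy(out.data(), raw.data(), raw.size());
    return out;
}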