"example/vscode:/vscode.git/clone" did not exist on "4d40b1974e18e9215067fb4b1117213e69a2923e"
Commit f3ddd797 authored by Paul's avatar Paul
Browse files

Add fp16 to onnx

parent 9ca0fbf1
...@@ -520,7 +520,7 @@ struct onnx_parser ...@@ -520,7 +520,7 @@ struct onnx_parser
case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()}; case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()}; case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error(""); case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()}; case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT64: throw std::runtime_error("");
...@@ -548,7 +548,8 @@ struct onnx_parser ...@@ -548,7 +548,8 @@ struct onnx_parser
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error(""); case onnx::TensorProto::FLOAT16:
return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()};
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return literal{ return literal{
{shape::double_type, dims}, t.double_data().begin(), t.double_data().end()}; {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
...@@ -579,8 +580,7 @@ struct onnx_parser ...@@ -579,8 +580,7 @@ struct onnx_parser
break; // throw std::runtime_error("Unsupported type STRING"); break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL"); break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
break; // throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break; case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break; case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break; case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
......
...@@ -132,6 +132,8 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -132,6 +132,8 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{ {
if(ins->name() != "gpu::convolution") if(ins->name() != "gpu::convolution")
return false; return false;
if(ins->get_shape().type() != shape::float_type)
return false;
auto wei = ins->inputs().at(1)->get_shape(); auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4); assert(wei.lens().size() == 4);
auto conv = any_cast<miopen_convolution>(ins->get_operator()); auto conv = any_cast<miopen_convolution>(ins->get_operator());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment