"src/vscode:/vscode.git/clone" did not exist on "0eb33f28a6a4d7e8c2d1161937816f48592b50fb"
Commit 752cb65a authored by Umang Yadav
Browse files

WIP mobilenet

parent 4926f035
...@@ -72,8 +72,8 @@ struct dequantizelinear ...@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) { visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) { visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) - output[i] = static_cast<double>(static_cast<double>(input[i]) -
static_cast<int64_t>(zero_pts[i])) * static_cast<double>(zero_pts[i])) *
scales[i]; scales[i];
}); });
}); });
......
...@@ -58,6 +58,10 @@ struct quantizelinear ...@@ -58,6 +58,10 @@ struct quantizelinear
{ {
return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()}; return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
} }
if(inputs[0].type() == shape::float_type)
{
return {shape::fp8e4m3fnuz_type, inputs[0].lens(), inputs[0].strides()};
}
return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()}; return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()};
} }
...@@ -80,10 +84,10 @@ struct quantizelinear ...@@ -80,10 +84,10 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) + double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]); static_cast<double>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value), output[i] = std::max(static_cast<double>(min_value),
std::min(static_cast<int64_t>(max_value), quantized)); std::min(static_cast<double>(max_value), quantized));
}); });
}); });
}); });
......
...@@ -549,7 +549,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -549,7 +549,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data()); return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: { case onnx::TensorProto::FLOAT8E4M3FN: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end()); std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8; std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(), std::transform(data_int32.begin(),
...@@ -560,7 +560,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -560,7 +560,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
} }
case onnx::TensorProto::FLOAT8E5M2FNUZ: case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2: case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN: case onnx::TensorProto::FLOAT8E4M3FNUZ:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
...@@ -625,11 +625,11 @@ shape::type_t get_type(int dtype) ...@@ -625,11 +625,11 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type; case 17: return shape::fp8e4m3fnuz_type;
case 14: case 14:
case 15: case 15:
case 16: case 16:
case 17: case 18:
case 19: case 19:
case 20: case 20:
default: { default: {
......
...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
int64_t max_quant = 0; double max_quant = 0;
int64_t min_quant = 0; double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) { ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{})) if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{ {
std::vector<int> min_data(s.elements(), min_quant); std::vector<double> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant); std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data)); min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data)); max_arg = m.add_literal(literal(s, max_data));
} }
......
...@@ -125,9 +125,11 @@ struct match_find_quantizable_ops ...@@ -125,9 +125,11 @@ struct match_find_quantizable_ops
auto zp1 = r.instructions["zp1"]; auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"]; auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported // Only INT8 or FP8 type currently supported
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type) migraphx::shape::int8_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return; return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed) // Only symmetric quantization supported (ie. non-zero zero_points not allowed)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment