"src/targets/vscode:/vscode.git/clone" did not exist on "26bd92d87703c18ee0be8d9ce5a90286b8e0f5db"
Commit 752cb65a authored by Umang Yadav's avatar Umang Yadav
Browse files

WIP mobilenet

parent 4926f035
@@ -72,8 +72,8 @@ struct dequantizelinear
     visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
         visit_all(result, x_scale)([&](auto output, auto scales) {
             par_for(output_shape.elements(), [&](auto i) {
-                output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
-                                                static_cast<int64_t>(zero_pts[i])) *
+                output[i] = static_cast<double>(static_cast<double>(input[i]) -
+                                                static_cast<double>(zero_pts[i])) *
                             scales[i];
             });
         });
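Note on the hunk above: fp8 inputs carry fractional values, so the old cast through int64_t truncated them before the scale was applied; staying in double preserves them. A minimal standalone sketch of the difference, with plain float standing in for the fp8 element type:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        float x      = 1.5f; // stand-in for an fp8e4m3fnuz element
        float zp     = 0.0f; // zero point
        double scale = 2.0;
        // Old path: the int64_t cast drops the fraction before scaling.
        double wrong = static_cast<double>(static_cast<int64_t>(x) -
                                           static_cast<int64_t>(zp)) *
                       scale;
        // New path: casting through double keeps it.
        double right = static_cast<double>(static_cast<double>(x) -
                                           static_cast<double>(zp)) *
                       scale;
        std::cout << wrong << " vs " << right << "\n"; // prints "2 vs 3"
    }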
@@ -58,6 +58,10 @@ struct quantizelinear
         {
             return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
         }
+        if(inputs[0].type() == shape::float_type)
+        {
+            return {shape::fp8e4m3fnuz_type, inputs[0].lens(), inputs[0].strides()};
+        }
         return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()};
     }
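Note on the hunk above: quantizelinear's default output type changes here. When no zero-point input dictates the type, float inputs now quantize to fp8e4m3fnuz instead of uint8. A hedged paraphrase of the selection order (assuming the elided guard before the first return checks for a third, zero-point input):

    // Sketch only, not the real member function.
    shape::type_t output_type(const std::vector<shape>& inputs)
    {
        if(inputs.size() == 3)                    // zero point given: its type wins
            return inputs[2].type();
        if(inputs[0].type() == shape::float_type) // new WIP default for float
            return shape::fp8e4m3fnuz_type;
        return shape::uint8_type;                 // previous fallback
    }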
@@ -80,10 +84,10 @@ struct quantizelinear
         auto min_value = std::numeric_limits<quant_type>::min();
         auto max_value = std::numeric_limits<quant_type>::max();
         par_for(output_shape.elements(), [&](auto i) {
-            int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) +
-                                static_cast<int64_t>(zero_pts[i]);
-            output[i] = std::max(static_cast<int64_t>(min_value),
-                                 std::min(static_cast<int64_t>(max_value), quantized));
+            double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
+                               static_cast<double>(zero_pts[i]);
+            output[i] = std::max(static_cast<double>(min_value),
+                                 std::min(static_cast<double>(max_value), quantized));
         });
     });
 });
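Note on the hunk above: as with dequantizelinear, the rounding and the clamp now stay in double, so quantized types whose extrema are not int64-friendly still round-trip correctly. The per-element operation, reduced to a self-contained sketch:

    #include <algorithm>
    #include <cmath>

    // q = clamp(nearbyint(x / scale) + zero_point, min, max), all in double.
    double quantize_elem(double x, double scale, double zp,
                         double min_value, double max_value)
    {
        const double q = std::nearbyint(x / scale) + zp;
        return std::max(min_value, std::min(max_value, q));
    }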
@@ -549,7 +549,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     case onnx::TensorProto::DOUBLE:
         return create_literal(shape::double_type, dims, t.double_data());
     case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
-    case onnx::TensorProto::FLOAT8E4M3FNUZ: {
+    case onnx::TensorProto::FLOAT8E4M3FN: {
         std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
         std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
         std::transform(data_int32.begin(),
@@ -560,7 +560,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     }
     case onnx::TensorProto::FLOAT8E5M2FNUZ:
     case onnx::TensorProto::FLOAT8E5M2:
-    case onnx::TensorProto::FLOAT8E4M3FN:
+    case onnx::TensorProto::FLOAT8E4M3FNUZ:
     case onnx::TensorProto::UNDEFINED:
     case onnx::TensorProto::STRING:
     case onnx::TensorProto::COMPLEX64:
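Note on the two parse_tensor hunks: they swap which fp8 flavor the parser accepts. FLOAT8E4M3FN initializers are now read (from int32_data, where ONNX stores one fp8 element per entry) and materialized as migraphx::fp8::fp8e4m3fnuz literals, while FLOAT8E4M3FNUZ moves to the unsupported list. The two formats are not bit-compatible (they differ in exponent bias and NaN encoding), so a value-preserving conversion goes through float. A hedged decoder sketch, not the MIGraphX implementation:

    #include <cmath>
    #include <cstdint>

    // FLOAT8E4M3FN: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits;
    // there are no infinities, and 0x7F/0xFF encode NaN.
    float fp8e4m3fn_to_float(uint8_t b)
    {
        const int sign = (b >> 7) & 1;
        const int exp  = (b >> 3) & 0xF;
        const int man  = b & 0x7;
        float v;
        if(exp == 0)                       // subnormal: no implicit leading 1
            v = std::ldexp(man / 8.0f, -6);
        else if(exp == 0xF and man == 0x7) // all-ones exponent and mantissa: NaN
            v = std::nanf("");
        else                               // normal: implicit leading 1
            v = std::ldexp(1.0f + man / 8.0f, exp - 7);
        return sign != 0 ? -v : v;
    }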
@@ -625,11 +625,11 @@ shape::type_t get_type(int dtype)
     case 11: return shape::double_type;
     case 12: return shape::uint32_type;
     case 13: return shape::uint64_type;
-    case 18: return shape::fp8e4m3fnuz_type;
+    case 17: return shape::fp8e4m3fnuz_type;
     case 14:
     case 15:
     case 16:
-    case 17:
+    case 18:
     case 19:
     case 20:
     default: {
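For reference, the ONNX TensorProto dtype codes involved here are: 14 COMPLEX64, 15 COMPLEX128, 16 BFLOAT16, 17 FLOAT8E4M3FN, 18 FLOAT8E4M3FNUZ, 19 FLOAT8E5M2, 20 FLOAT8E5M2FNUZ. The hunk therefore remaps dtype 17 (FLOAT8E4M3FN) onto MIGraphX's fp8e4m3fnuz_type and sends 18 to the unsupported default, consistent with the parse_tensor change above.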
@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
         add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
     }
-    int64_t max_quant = 0;
-    int64_t min_quant = 0;
+    double max_quant = 0;
+    double min_quant = 0;
     ins->get_shape().visit_type([&](auto qt) {
         max_quant = qt.max();
         min_quant = qt.min();
@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
     if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
     {
-        std::vector<int> min_data(s.elements(), min_quant);
-        std::vector<int> max_data(s.elements(), max_quant);
+        std::vector<double> min_data(s.elements(), min_quant);
+        std::vector<double> max_data(s.elements(), max_quant);
         min_arg = m.add_literal(literal(s, min_data));
         max_arg = m.add_literal(literal(s, max_data));
     }
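Note on the two hunks above: both widen the clamp bounds from integer to double. A plausible reading, stated as an assumption: once fp8 types can appear as the quantized type, qt.max() and qt.min() come from a float-like type, so int64_t scalars and std::vector<int> literals no longer fit every case, even though the fp8 extrema themselves happen to be integral (about ±240 for e4m3fnuz and ±448 for e4m3fn, versus -128/127 for int8).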
@@ -125,9 +125,11 @@ struct match_find_quantizable_ops
         auto zp1 = r.instructions["zp1"];
         auto zp2 = r.instructions["zp2"];
-        // Only INT8 type currently supported
-        if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
-           dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
+        // Only INT8 or FP8 types currently supported
+        std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
+                                                             migraphx::shape::int8_type};
+        if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
+           not contains(supported_types, dq2->inputs().front()->get_shape().type()))
             return;
         // Only symmetric quantization supported (i.e. non-zero zero_points not allowed)
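Note on the hunk above: the matcher's gate generalizes from a single int8 comparison to membership in a set of supported element types, so adding further quantized types later is a one-line change. A self-contained sketch of the same pattern, with a stand-in enum for migraphx::shape::type_t and std::set::count playing the role of MIGraphX's contains():

    #include <iostream>
    #include <set>

    enum class type_t { int8_type, fp8e4m3fnuz_type, float_type };

    bool quantizable(type_t t1, type_t t2)
    {
        static const std::set<type_t> supported = {type_t::int8_type,
                                                   type_t::fp8e4m3fnuz_type};
        return supported.count(t1) > 0 and supported.count(t2) > 0;
    }

    int main()
    {
        std::cout << quantizable(type_t::int8_type, type_t::fp8e4m3fnuz_type)    // 1
                  << quantizable(type_t::float_type, type_t::int8_type) << "\n"; // 0
    }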