Unverified Commit db3c07fb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add `--fp8` option to quantize models in FP8 inside `migraphx-driver` (#2535)

parent aac4e950
......@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant.
.. envvar:: MIGRAPHX_INT8_QUANTIZATION_PARAMS
.. envvar:: MIGRAPHX_8BITS_QUANTIZATION_PARAMS
Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the quantization parameters in only the main module.
......
......@@ -38,3 +38,6 @@ Quantize for fp16
Quantize for int8
.. option:: --fp8
Quantize for Float8E4M3FNUZ type
......@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 |
| --fp8 | Quantize for Float8E4M3FNUZ type |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
......
......@@ -81,7 +81,7 @@ add_library(migraphx
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
quantize_int8.cpp
quantize_8bits.cpp
reduce_dims.cpp
register_op.cpp
register_target.cpp
......
......@@ -445,6 +445,7 @@ struct compiler
compiler_target ct;
compile_options co;
bool to_fp16 = false;
bool to_fp8 = false;
bool to_int8 = false;
std::vector<std::string> fill0;
......@@ -468,6 +469,7 @@ struct compiler
ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8e4m3fnuz type"), ap.set_value(true));
}
auto params(const program& p)
......@@ -518,6 +520,10 @@ struct compiler
{
quantize_int8(p, t, {host_params(p)});
}
if(to_fp8)
{
quantize_fp8(p, t, {host_params(p)});
}
p.compile(t, co);
l.save(p);
return p;
......
......@@ -46,6 +46,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot",
"convolution"});
MIGRAPHX_EXPORT void
quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string>
#include <vector>
......@@ -37,7 +37,7 @@ struct program;
struct module;
/**
* capture inputs of operators to be quantized to int8
* capture inputs of operators to be quantized to int8 or fp8
*/
struct MIGRAPHX_EXPORT capture_arguments_pass
{
......@@ -49,13 +49,14 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
};
/**
* quantize a program to int8
* quantize a program to int8 or fp8
*/
struct MIGRAPHX_EXPORT quantize_int8_pass
struct MIGRAPHX_EXPORT quantize_8bits_pass
{
shape::type_t precision = shape::int8_type;
std::vector<std::string> ins_names = {"dot", "convolution"};
std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_int8"; }
std::string name() const { return "quantize_8bits"; }
void apply(module& m) const;
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -25,7 +25,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
......@@ -45,7 +45,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
......@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
}
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
void quantize_8bits(program& prog,
const target& t,
shape::type_t precision,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes(prog, {optimize_module{}});
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
std::vector<argument> args) {
float quantized_range = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
......@@ -90,23 +83,22 @@ void quantize_int8(program& prog,
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
// if all values are 0, no need to do scaling
if(max_abs_vals->at(ins_index) == 0.0f)
if(float_equal(max_abs_vals->at(ins_index), 0.0f))
{
param_pair.first = 1.0f;
}
else
{
param_pair.first = 127.0f / max_abs_vals->at(ins_index);
param_pair.first = quantized_range / max_abs_vals->at(ins_index);
}
int8_quant_params->at(ins_index) = param_pair;
quant_8bit_params->at(ins_index) = param_pair;
};
// pass to add capture argument op
std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
......@@ -134,11 +126,11 @@ void quantize_int8(program& prog,
}
// print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
if(enabled(MIGRAPHX_8BITS_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
for(std::size_t i = 0; i < quant_8bit_params->size(); ++i)
{
auto param = int8_quant_params->at(i);
auto param = quant_8bit_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
......@@ -146,11 +138,46 @@ void quantize_int8(program& prog,
}
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
{quantize_8bits_pass{precision, ins_names, *quant_8bit_params},
simplify_qdq{},
optimize_module{},
dead_code_elimination{}});
}
// Quantize a program to int8 using the supplied calibration data.
// Only "dot" and "convolution" may be requested in `ins_names`; any other
// operator name is rejected up front before delegating to the shared
// 8-bit quantization pipeline.
void quantize_int8(program& prog,
                   const target& t,
                   const std::vector<parameter_map>& calibration,
                   const std::vector<std::string>& ins_names)
{
    static const std::set<std::string> allowed = {"convolution", "dot"};
    const bool all_supported = std::all_of(ins_names.begin(),
                                           ins_names.end(),
                                           [](const std::string& name) { return allowed.count(name) > 0; });
    if(not all_supported)
    {
        MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
    }
    quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}
// Quantize a program to fp8e4m3fnuz using the supplied calibration data.
// Unlike quantize_int8, this does not restrict the operator set: it collects
// every operator name present in the main module (minus the exclusions below)
// and hands that list to the shared 8-bit quantization pipeline.
void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
    // Warn on stderr so the BETA notice does not pollute normal program output.
    std::cerr << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
                 "incorrect final outputs\n";
    // Gather the distinct quantizable operator names in the main module.
    // "convert" is excluded — presumably to avoid re-capturing converts the
    // pipeline itself inserts (TODO confirm) — as are internal instructions
    // whose names start with '@' (e.g. @param, @literal, @return).
    // Using a set deduplicates the names: the original pushed one entry per
    // instruction, which made every membership scan downstream scale with the
    // instruction count instead of the operator count.
    std::set<std::string> op_names;
    auto* mm = prog.get_main_module();
    for(auto ins : iterator_for(*mm))
    {
        const auto& name = ins->name();
        if(name == "convert" or starts_with(name, "@"))
        {
            continue;
        }
        op_names.insert(name);
    }
    std::vector<std::string> supported_ins_names(op_names.begin(), op_names.end());
    quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
......@@ -41,8 +41,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantable_types = {
......@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return quantable_types;
}
void quantize_int8_pass::apply(module& m) const // NOLINT
void quantize_8bits_pass::apply(module& m) const // NOLINT
{
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
......@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto input = ins->inputs().front();
auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
if(contains(quantizable_types, s.type()) and s.type() != precision)
{
auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
auto zero_point =
m.add_literal(migraphx::literal{migraphx::shape{precision}, {param.second}});
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens();
scale =
......@@ -87,19 +86,32 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void capture_arguments_pass::apply(module& m) const // NOLINT
{
assert(param_index != nullptr);
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
if(ins->name() == "convert")
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
if(contains(quantizable_types, input->get_shape().type()))
{
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
else
{
new_args.push_back(input);
}
}
m.replace_instruction(ins, ins->get_operator(), new_args);
}
......
......@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals =
std::all_of(
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); });
std::all_of(l1.begin() + 1,
l1.end(),
[&](auto v) {
return ((float_equal(v, l1.front())) or
(std::isinf(l1.front()) and std::isinf(v)));
}) and
std::all_of(l2.begin(), l2.end(), [&](auto v) {
return ((float_equal(v, l1.front())) or (std::isinf(l1.front()) and std::isinf(v)));
});
});
return (x == y) or diff_shapes_equal_vals;
......
......@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
......@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params},
migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
......@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}});
{migraphx::quantize_8bits_pass{migraphx::shape::int8_type, {"dot"}, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p);
......@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
......@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
test::throws([&] {
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"add"}, quant_params}});
});
}
......@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
......@@ -1231,7 +1240,10 @@ TEST_CASE(int8_subgraph)
std::size_t param_index = 0;
migraphx::run_passes(
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}});
migraphx::run_passes(p1,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type,
{"convolution", "dot"},
quant_params}});
optimize_prog_int8(p1);
auto p2 = create_int8_program();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment