Unverified Commit db3c07fb authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add `--fp8` option to quantize models in FP8 inside `migraphx-driver` (#2535)

parent aac4e950
...@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass. ...@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant. Traces instructions replaced with a constant.
.. envvar:: MIGRAPHX_INT8_QUANTIZATION_PARAMS .. envvar:: MIGRAPHX_8BITS_QUANTIZATION_PARAMS
Set to "1", "enable", "enabled", "yes", or "true" to use. Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the quantization parameters in only the main module. Print the quantization parameters in only the main module.
......
...@@ -38,3 +38,6 @@ Quantize for fp16 ...@@ -38,3 +38,6 @@ Quantize for fp16
Quantize for int8 Quantize for int8
.. option:: --fp8
Quantize for Float8E4M3FNUZ type
...@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as ...@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel | | --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 | | --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 | | --int8 | Quantize for int8 |
| --fp8 | Quantize for Float8E4M3FNUZ type |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) | | --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) | | --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) | | --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
......
...@@ -81,7 +81,7 @@ add_library(migraphx ...@@ -81,7 +81,7 @@ add_library(migraphx
promote_literals.cpp promote_literals.cpp
quantization.cpp quantization.cpp
quantize_fp16.cpp quantize_fp16.cpp
quantize_int8.cpp quantize_8bits.cpp
reduce_dims.cpp reduce_dims.cpp
register_op.cpp register_op.cpp
register_target.cpp register_target.cpp
......
...@@ -445,6 +445,7 @@ struct compiler ...@@ -445,6 +445,7 @@ struct compiler
compiler_target ct; compiler_target ct;
compile_options co; compile_options co;
bool to_fp16 = false; bool to_fp16 = false;
bool to_fp8 = false;
bool to_int8 = false; bool to_int8 = false;
std::vector<std::string> fill0; std::vector<std::string> fill0;
...@@ -468,6 +469,7 @@ struct compiler ...@@ -468,6 +469,7 @@ struct compiler
ap.set_value(true)); ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true)); ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true)); ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8e4m3fnuz type"), ap.set_value(true));
} }
auto params(const program& p) auto params(const program& p)
...@@ -518,6 +520,10 @@ struct compiler ...@@ -518,6 +520,10 @@ struct compiler
{ {
quantize_int8(p, t, {host_params(p)}); quantize_int8(p, t, {host_params(p)});
} }
if(to_fp8)
{
quantize_fp8(p, t, {host_params(p)});
}
p.compile(t, co); p.compile(t, co);
l.save(p); l.save(p);
return p; return p;
......
...@@ -46,6 +46,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog, ...@@ -46,6 +46,8 @@ MIGRAPHX_EXPORT void quantize_int8(program& prog,
const std::vector<parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", const std::vector<std::string>& ins_names = {"dot",
"convolution"}); "convolution"});
MIGRAPHX_EXPORT void
quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP #define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -37,7 +37,7 @@ struct program; ...@@ -37,7 +37,7 @@ struct program;
struct module; struct module;
/** /**
* capture inputs of operators to be quantized to int8 * capture inputs of operators to be quantized to int8 or fp8
*/ */
struct MIGRAPHX_EXPORT capture_arguments_pass struct MIGRAPHX_EXPORT capture_arguments_pass
{ {
...@@ -49,13 +49,14 @@ struct MIGRAPHX_EXPORT capture_arguments_pass ...@@ -49,13 +49,14 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
}; };
/** /**
* quantize a program to int8 * quantize a program to int8 or fp8
*/ */
struct MIGRAPHX_EXPORT quantize_int8_pass struct MIGRAPHX_EXPORT quantize_8bits_pass
{ {
shape::type_t precision = shape::int8_type;
std::vector<std::string> ins_names = {"dot", "convolution"}; std::vector<std::string> ins_names = {"dot", "convolution"};
std::vector<std::pair<float, float>> quant_params; std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_int8"; } std::string name() const { return "quantize_8bits"; }
void apply(module& m) const; void apply(module& m) const;
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp> #include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_8bits.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
...@@ -45,7 +45,7 @@ ...@@ -45,7 +45,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
// This function is to convert any instructions specified in the input // This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator. // from double or float to float16 by inserting a convert operator.
...@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}}); run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
} }
void quantize_int8(program& prog, void quantize_8bits(program& prog,
const target& t, const target& t,
shape::type_t precision,
const std::vector<parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::vector<std::string>& ins_names)
{ {
std::set<std::string> op_names = {"convolution", "dot"}; // Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// avoid loss of precision. // avoid loss of precision.
run_passes(prog, {optimize_module{}}); run_passes(prog, {optimize_module{}});
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params = std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
std::make_shared<std::vector<std::pair<float, float>>>(); std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>(); std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index, float quantized_range = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
std::vector<argument> args) { auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f}; std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0 // consider shift, so set shift to 0
...@@ -90,23 +83,22 @@ void quantize_int8(program& prog, ...@@ -90,23 +83,22 @@ void quantize_int8(program& prog,
auto min_val = *std::min_element(vec_val.begin(), vec_val.end()); auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val)); auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs); max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
// if all values are 0, no need to do scaling // if all values are 0, no need to do scaling
if(max_abs_vals->at(ins_index) == 0.0f) if(float_equal(max_abs_vals->at(ins_index), 0.0f))
{ {
param_pair.first = 1.0f; param_pair.first = 1.0f;
} }
else else
{ {
param_pair.first = 127.0f / max_abs_vals->at(ins_index); param_pair.first = quantized_range / max_abs_vals->at(ins_index);
} }
int8_quant_params->at(ins_index) = param_pair; quant_8bit_params->at(ins_index) = param_pair;
}; };
// pass to add capture argument op // pass to add capture argument op
std::size_t param_num = 0; std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}}); run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f)); quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f); max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale // use the calibration data to compute the quantization scale
...@@ -134,11 +126,11 @@ void quantize_int8(program& prog, ...@@ -134,11 +126,11 @@ void quantize_int8(program& prog,
} }
// print the quantization parameters in only the main module // print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{})) if(enabled(MIGRAPHX_8BITS_QUANTIZATION_PARAMS{}))
{ {
for(std::size_t i = 0; i < int8_quant_params->size(); ++i) for(std::size_t i = 0; i < quant_8bit_params->size(); ++i)
{ {
auto param = int8_quant_params->at(i); auto param = quant_8bit_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl; << ", shift = " << param.second << std::endl;
} }
...@@ -146,11 +138,46 @@ void quantize_int8(program& prog, ...@@ -146,11 +138,46 @@ void quantize_int8(program& prog,
} }
run_passes(prog, run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params}, {quantize_8bits_pass{precision, ins_names, *quant_8bit_params},
simplify_qdq{}, simplify_qdq{},
optimize_module{}, optimize_module{},
dead_code_elimination{}}); dead_code_elimination{}});
} }
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}
void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
std::vector<std::string> supported_ins_names;
auto* mm = prog.get_main_module();
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "convert")
{
continue;
}
else if(not starts_with(ins->name(), "@"))
{
supported_ins_names.push_back(ins->name());
}
}
quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_8bits.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
...@@ -41,8 +41,6 @@ ...@@ -41,8 +41,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type() static std::vector<shape::type_t>& get_quantizable_type()
{ {
static std::vector<shape::type_t> quantable_types = { static std::vector<shape::type_t> quantable_types = {
...@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type() ...@@ -50,7 +48,7 @@ static std::vector<shape::type_t>& get_quantizable_type()
return quantable_types; return quantable_types;
} }
void quantize_int8_pass::apply(module& m) const // NOLINT void quantize_8bits_pass::apply(module& m) const // NOLINT
{ {
const auto& quantizable_types = get_quantizable_type(); const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
...@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT ...@@ -66,9 +64,10 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
auto input = ins->inputs().front(); auto input = ins->inputs().front();
auto s = input->get_shape(); auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type) if(contains(quantizable_types, s.type()) and s.type() != precision)
{ {
auto zero_point = m.add_literal(static_cast<int8_t>(param.second)); auto zero_point =
m.add_literal(migraphx::literal{migraphx::shape{precision}, {param.second}});
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first})); auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens(); const auto& lens = s.lens();
scale = scale =
...@@ -87,20 +86,33 @@ void quantize_int8_pass::apply(module& m) const // NOLINT ...@@ -87,20 +86,33 @@ void quantize_int8_pass::apply(module& m) const // NOLINT
void capture_arguments_pass::apply(module& m) const // NOLINT void capture_arguments_pass::apply(module& m) const // NOLINT
{ {
assert(param_index != nullptr); assert(param_index != nullptr);
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(not contains(ins_names, ins->name())) if(not contains(ins_names, ins->name()))
{ {
continue; continue;
} }
if(ins->name() == "convert")
{
continue;
}
auto inputs = ins->inputs(); auto inputs = ins->inputs();
std::vector<instruction_ref> new_args; std::vector<instruction_ref> new_args;
for(auto input : inputs) for(auto input : inputs)
{
if(contains(quantizable_types, input->get_shape().type()))
{ {
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input); auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in); new_args.push_back(new_in);
} }
else
{
new_args.push_back(input);
}
}
m.replace_instruction(ins, ins->get_operator(), new_args); m.replace_instruction(ins, ins->get_operator(), new_args);
} }
} }
......
...@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2) ...@@ -210,9 +210,15 @@ bool compare_literals(instruction_ref ins1, instruction_ref ins2)
bool diff_shapes_equal_vals = false; bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) { visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals = diff_shapes_equal_vals =
std::all_of( std::all_of(l1.begin() + 1,
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and l1.end(),
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); }); [&](auto v) {
return ((float_equal(v, l1.front())) or
(std::isinf(l1.front()) and std::isinf(v)));
}) and
std::all_of(l2.begin(), l2.end(), [&](auto v) {
return ((float_equal(v, l1.front())) or (std::isinf(l1.front()) and std::isinf(v)));
});
}); });
return (x == y) or diff_shapes_equal_vals; return (x == y) or diff_shapes_equal_vals;
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_fp16.hpp> #include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
...@@ -654,7 +654,8 @@ TEST_CASE(dot_float) ...@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params},
migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args) ...@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog()); EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg) ...@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_8bits_pass{migraphx::shape::int8_type, {"dot"}, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog()); EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -876,7 +879,9 @@ TEST_CASE(conv_float) ...@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}}); migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}});
optimize_prog_int8(p); optimize_prog_int8(p);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw) ...@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
test::throws([&] { test::throws([&] {
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}}); migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"add"}, quant_params}});
}); });
} }
...@@ -952,7 +959,9 @@ TEST_CASE(conv_half) ...@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}}); migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}});
optimize_prog_int8(p); optimize_prog_int8(p);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -1231,7 +1240,10 @@ TEST_CASE(int8_subgraph) ...@@ -1231,7 +1240,10 @@ TEST_CASE(int8_subgraph)
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes( migraphx::run_passes(
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}}); p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}}); migraphx::run_passes(p1,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type,
{"convolution", "dot"},
quant_params}});
optimize_prog_int8(p1); optimize_prog_int8(p1);
auto p2 = create_int8_program(); auto p2 = create_int8_program();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment