Commit cac6c759 authored by Paul

Merge

parents 4bde67c4 a60bdb67
......@@ -89,7 +89,7 @@ requests==2.28.2
# via
# pygithub
# sphinx
rocm-docs-core==0.30.0
rocm-docs-core==0.30.1
# via -r requirements.in
smmap==5.0.0
# via gitdb
......
......@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant.
.. envvar:: MIGRAPHX_INT8_QUANTIZATION_PARAMS
.. envvar:: MIGRAPHX_8BITS_QUANTIZATION_PARAMS
Set to "1", "enable", "enabled", "yes", or "true" to use.
Print the quantization parameters in only the main module.
......
......@@ -38,3 +38,6 @@ Quantize for fp16
Quantize for int8
.. option:: --fp8
Quantize for Float8E4M3FNUZ type
......@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 |
| --fp8 | Quantize for Float8E4M3FNUZ type |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
......
......@@ -23,10 +23,9 @@
#####################################################################################
google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On
ROCmSoftwarePlatform/rocMLIR@08597bdd875eab888d2df826863828cdca5c8bb4 -DBUILD_FAT_LIBROCKCOMPILER=On
......@@ -81,7 +81,7 @@ add_library(migraphx
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
quantize_int8.cpp
quantize_8bits.cpp
reduce_dims.cpp
register_op.cpp
register_target.cpp
......
......@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct quantize_int8_options
{
std::vector<parameter_map> calibration = {};
std::vector<std::string> op_names = {};
std::vector<parameter_map> calibration = {};
std::unordered_set<std::string> op_names = {};
};
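// Note: with op_names stored in an unordered_set, repeated add_op_name calls are deduplicated automatically.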
void add_op_name(quantize_int8_options& options, const char* name)
{
options.op_names.push_back(name);
options.op_names.insert(name);
}
void add_calibration_data(quantize_int8_options& options, parameter_map& data)
......
......@@ -445,6 +445,7 @@ struct compiler
compiler_target ct;
compile_options co;
bool to_fp16 = false;
bool to_fp8 = false;
bool to_int8 = false;
std::vector<std::string> fill0;
......@@ -468,6 +469,7 @@ struct compiler
ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8e4m3fnuz type"), ap.set_value(true));
}
auto params(const program& p)
......@@ -518,6 +520,10 @@ struct compiler
{
quantize_int8(p, t, {host_params(p)});
}
if(to_fp8)
{
quantize_fp8(p, t, {host_params(p)});
}
p.compile(t, co);
l.save(p);
return p;
......
......@@ -27,6 +27,17 @@
#include <utility>
#include <migraphx/config.hpp>
// Similar to decltype(auto), except it will propagate any substitution failures
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// Lifts an expression into a function object so it can be passed to a higher-order function
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
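A self-contained sketch of what the lift macro buys (the `add` overloads below are illustrative, not part of this commit): an overload set cannot be passed to an algorithm directly, but the lifted lambda can, since overload resolution is deferred to the call site.

```cpp
#include <numeric>
#include <vector>

// Macros reproduced from the hunk above so the sketch is self-contained.
#define MIGRAPHX_RETURNS(...) \
    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
#define MIGRAPHX_LIFT(...) \
    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))

int add(int a, int b) { return a + b; }
double add(double a, double b) { return a + b; }

int main()
{
    std::vector<int> v = {1, 2, 3, 4};
    // std::accumulate(v.begin(), v.end(), 0, add) would not compile: `add` names an
    // overload set, not a value. The lifted lambda picks the right overload per call.
    return std::accumulate(v.begin(), v.end(), 0, MIGRAPHX_LIFT(add));
}
```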
......@@ -32,8 +32,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T, class F>
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
template <class T, class U, class F>
void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
......@@ -52,7 +52,8 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
......
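The widened signature lets A and B use a narrower element type than C (for example 8-bit inputs with a float output), with the inner product accumulated in double. A hedged, standalone restatement of the same pattern on flat row-major buffers (all names and the layout are illustrative, not the committed tensor_view code):

```cpp
#include <cstddef>
#include <vector>

// Reference gemm in the spirit of the change: U inputs, T output, double accumulator.
// Computes C = alpha * A*B + beta * C, with A m-by-k, B k-by-n, C m-by-n, row-major.
template <class T, class U, class F>
void gemm_ref(std::vector<T>& c, const std::vector<U>& a, const std::vector<U>& b,
              std::size_t m, std::size_t n, std::size_t k, F alpha, F beta)
{
    for(std::size_t i = 0; i < m; i++)
    {
        for(std::size_t j = 0; j < n; j++)
        {
            double s = 0.0; // accumulate in double regardless of the input type U
            for(std::size_t kk = 0; kk < k; kk++)
                s += static_cast<double>(a[i * k + kk]) * static_cast<double>(b[kk * n + j]);
            c[i * n + j] = alpha * s + c[i * n + j] * beta;
        }
    }
}
```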
......@@ -44,9 +44,11 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(t != shape::int8_type)
std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(supported_types, t))
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
}
if(not std::all_of(
......@@ -73,6 +75,10 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens};
}
};
......
......@@ -112,84 +112,6 @@ struct reshape
return {s0.type(), output_dyn_dims};
}
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. Unlike reshape_lazy, though, it can modify the memory layout with copies, so it
// can remove the nullopts that reshape_lazy returns for the alias case
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
auto n = it - start;
assert((i + n) <= istrides.size());
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
(void)d;
rstrides.push_back(stride);
}
}
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
......@@ -219,14 +141,14 @@ struct reshape
}
}
auto s = reshape_dims(inputs.front(), rdims);
auto s = shape{inputs.front().type(), rdims};
if(s->elements() != inputs.front().elements())
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
return *s;
return s;
}
shape compute_shape(std::vector<shape> inputs) const
......
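With the aliasing path removed, `static_compute_shape` now simply builds a standard shape from `rdims` and validates the element count. A minimal sketch of that check in isolation (the helper name is hypothetical):

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative: a reshape is valid only when it preserves the element count,
// e.g. {2, 3, 4} -> {6, 4} is fine (24 elements), while {2, 3, 4} -> {5, 4} throws.
std::vector<std::size_t> checked_reshape(const std::vector<std::size_t>& in,
                                         const std::vector<std::size_t>& rdims)
{
    auto count = [](const std::vector<std::size_t>& dims) {
        return std::accumulate(dims.begin(), dims.end(), std::size_t{1}, std::multiplies<>{});
    };
    if(count(rdims) != count(in))
        throw std::runtime_error("reshape: Wrong number of elements for reshape: reshape has " +
                                 std::to_string(count(rdims)) + " elements whereas the input has " +
                                 std::to_string(count(in)) + " elements");
    return rdims;
}
```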
......@@ -110,22 +110,69 @@ struct reshape_lazy
return it;
}
template <class OptionalPair>
static OptionalPair try_merge_pairs(OptionalPair p2, OptionalPair p1)
{
if(not p1.has_value())
return nullopt;
if(not p2.has_value())
return nullopt;
auto dim1 = p1->first;
auto dim2 = p2->first;
auto stride1 = p1->second;
auto stride2 = p2->second;
auto elements = dim1 * dim2;
// Transposed
if(stride2 > stride1)
return nullopt;
// Broadcasted check to avoid division by zero
if(stride2 == 0)
{
if(stride1 == 0)
return {{elements, 0}};
return nullopt;
}
if(stride1 % stride2 != 0)
return nullopt;
auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
// Nonpacked
if(space != elements)
return nullopt;
return {{elements, stride2}};
}
template <class DimIterator, class StrideIterator>
static optional<std::size_t> merge_strides(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
if(dim_start == dim_last)
return nullopt;
(void)stride_start; // Is only used in the assert
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto make_pair_optional = [&](auto dim, auto stride) {
return std::make_optional(std::make_pair(dim, stride));
};
auto dim_stride_pair =
std::inner_product(std::make_reverse_iterator(dim_last - 1),
std::make_reverse_iterator(dim_start),
std::make_reverse_iterator(stride_last - 1),
make_pair_optional(*std::prev(dim_last), *std::prev(stride_last)),
MIGRAPHX_LIFT(try_merge_pairs),
make_pair_optional);
if(not dim_stride_pair.has_value())
return nullopt;
return dim_stride_pair->second;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
return merge_strides(dim_start, dim_last, stride_start, stride_last).has_value();
}
// This will attempt to alias the dimensions of the input shape to the lens of
......
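For concreteness, a standalone restatement of the merge rule with worked numbers (this mirrors `try_merge_pairs` above but is a sketch, not the committed code): two adjacent (dim, stride) pairs merge only when they describe one contiguous, non-transposed, non-padded run of memory.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>

using dim_stride = std::pair<std::size_t, std::size_t>;

// outer = (dim1, stride1), inner = (dim2, stride2), with the inner dim varying fastest.
std::optional<dim_stride> try_merge(dim_stride outer, dim_stride inner)
{
    auto [dim1, stride1] = outer;
    auto [dim2, stride2] = inner;
    auto elements = dim1 * dim2;
    if(stride2 > stride1) // transposed
        return std::nullopt;
    if(stride2 == 0) // broadcasted; also avoids division by zero below
    {
        if(stride1 == 0)
            return dim_stride{elements, 0};
        return std::nullopt;
    }
    if(stride1 % stride2 != 0)
        return std::nullopt;
    // Space actually spanned by the two dimensions, in units of stride2.
    auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
    if(space != elements) // non-packed (padding between rows)
        return std::nullopt;
    return dim_stride{elements, stride2};
}

int main()
{
    assert(try_merge({2, 3}, {3, 1}).has_value());     // packed: merges to {6, 1}
    assert(not try_merge({2, 1}, {3, 2}).has_value()); // transposed: rejected
    assert(not try_merge({2, 8}, {3, 1}).has_value()); // padded rows: space 11 != 6 elements
}
```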
......@@ -44,8 +44,10 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
MIGRAPHX_EXPORT void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot",
"convolution"});
const std::unordered_set<std::string>& ins_names = {
"dot", "convolution"});
MIGRAPHX_EXPORT void
quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
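A hedged usage sketch of the two entry points declared above, assuming the public `parse_onnx`/`make_target` API and a hypothetical model file; treat it as illustrative rather than the committed tests:

```cpp
#include <vector>
#include <migraphx/onnx.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_target.hpp>

int main()
{
    auto prog = migraphx::parse_onnx("model.onnx"); // hypothetical model path
    auto t    = migraphx::make_target("gpu");

    // int8 path: calibration data is needed for meaningful scales; with an empty
    // vector the (scale, shift) pairs keep the (64, 0) default set by the capture pass.
    std::vector<migraphx::parameter_map> calibration;
    migraphx::quantize_int8(prog, t, calibration);

    // fp8 path: no op-name list; per this commit it quantizes all supported
    // non-convert instructions in the main module.
    // migraphx::quantize_fp8(prog, t, calibration);

    prog.compile(t);
}
```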
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,10 +21,11 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string>
#include <unordered_set>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
......@@ -37,11 +38,11 @@ struct program;
struct module;
/**
* capture inputs of operators to be quantized to int8
* capture inputs of operators to be quantized to int8 or fp8
*/
struct MIGRAPHX_EXPORT capture_arguments_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::unordered_set<std::string> ins_names = {"dot", "convolution"};
std::function<void(std::size_t, std::vector<argument>)> f{};
std::size_t* param_index = nullptr;
std::string name() const { return "capture_arguments"; }
......@@ -49,13 +50,13 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
};
/**
* quantize a program to int8
* quantize a program to int8 or fp8
*/
struct MIGRAPHX_EXPORT quantize_int8_pass
struct MIGRAPHX_EXPORT quantize_8bits_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
shape::type_t precision = shape::int8_type;
std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_int8"; }
std::string name() const { return "quantize_8bits"; }
void apply(module& m) const;
};
......
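A sketch of constructing the renamed pass directly, mirroring the final `run_passes` call later in this commit. In the committed flow the (scale, shift) pairs come from `capture_arguments_pass` plus calibration, so the hard-coded values here are purely illustrative:

```cpp
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/simplify_qdq.hpp>

// Illustrative only: run the 8-bit quantization pass with hand-picked per-capture
// (scale, shift) parameters instead of calibration-derived ones.
void quantize_with_fixed_params(migraphx::program& prog)
{
    migraphx::quantize_8bits_pass pass{migraphx::shape::int8_type,
                                       {{127.0f / 6.0f, 0.0f}, {127.0f / 2.0f, 0.0f}}};
    migraphx::run_passes(
        prog, {pass, migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}});
}
```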
......@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts)
smod->finalize(contexts);
}
}
#ifndef BUILD_DEV
if(std::any_of(this->begin(), this->end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
}
#endif
// Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/*
*********************************************************************************
* Reference: see DynamicQuantizeLinear in *
* https://github.com/onnx/onnx/blob/main/docs/Operators.md *
*********************************************************************************
DynamicQuantizeLinear
A Function to fuse calculation for Scale, Zero Point and FP32->8Bit conversion of FP32 Input data.
Outputs Scale, ZeroPoint and Quantized Input for a given FP32 Input. Scale is calculated as:
y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
* where qmax and qmin are max and min values for quantization range i.e. [0, 255] in case of uint8
* data range is adjusted to include 0.
Zero point is calculated as:
intermediate_zero_point = qmin - min(x)/y_scale
y_zero_point = cast(round(saturate(intermediate_zero_point)))
* where qmax and qmin are max and min values for the quantization range, i.e. [0, 255] in case of uint8
* for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now
only uint8 is supported.
* rounding to nearest ties to even. Data quantization formula is:
y = saturate (round (x / y_scale) + y_zero_point)
* for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only
uint8 is supported.
* rounding to nearest ties to even.
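A worked example (illustrative, not part of the quoted spec): for x = [-1.0, 3.0],
y_scale = (max(0, 3.0) - min(0, -1.0)) / (255 - 0) = 4.0 / 255 ≈ 0.0157, and
intermediate_zero_point = (0 - (-1.0)) / y_scale = 63.75, so y_zero_point = round(63.75) = 64;
quantizing then gives y = saturate(round(x / y_scale) + 64) = [0, 255].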
Version
This version of the operator has been available since version 11 of the default ONNX operator set.
Inputs
x : T1
Input tensor
Outputs
y : T2
Quantized output tensor
y_scale : tensor(float)
Output scale. It's a scalar, which means a per-tensor/layer quantization.
y_zero_point : T2
Output zero point. It's a scalar, which means a per-tensor/layer quantization.
Type Constraints
T1 : tensor(float)
Constrain 'x' to float tensor.
T2 : tensor(uint8)
Constrain 'y_zero_point' and 'y' to 8-bit unsigned integer tensor.
*/
struct parse_dynamicquantizelinear : op_parser<parse_dynamicquantizelinear>
{
std::vector<op_desc> operators() const { return {{"DynamicQuantizeLinear"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
auto x = args[0];
auto x_shape = x->get_shape();
auto x_type = x_shape.type();
if(x_shape.dynamic())
MIGRAPHX_THROW("DYNAMICQUANTIZELINEAR: dynamic shapes are not supported");
auto x_reshaped =
(x_shape.lens().size() == 1)
? x
: info.add_instruction(
migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x);
auto lit_0 = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0}});
x_reshaped =
info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, lit_0);
// 1. Computing y_scale
// Note: currently, DynamicQuantizeLinear only has uint8 quantization:
const auto Q_MAX = std::numeric_limits<uint8_t>::max();
const auto Q_MIN = std::numeric_limits<uint8_t>::min();
auto q_range =
info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MAX - Q_MIN}});
// maximum(0, max(x))
auto max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
// minimum(0, min(x))
auto min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);
// y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
auto sub0 = info.add_common_op("sub", max_x, min_x);
auto y_scale = info.add_common_op("div", sub0, q_range);
// 2. Computing y_zero_point
// intermediate_zero_point = qmin - min(x) / y_scale
auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MIN}});
auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MAX}});
auto sub1 = info.add_common_op("sub", q_min, min_x);
auto interm_zp = info.add_common_op("div", sub1, y_scale);
// y_zero_point = cast(round(saturate(intermediate_zero_point)))
auto saturate = info.add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
auto round = info.add_instruction(migraphx::make_op("nearbyint"), saturate);
auto y_zero_point = info.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::uint8_type}}), round);
// 3. quantize x with y_scale and y_zero_point
auto quant = bcast_qdq_instr("quantizelinear", x, y_scale, y_zero_point, info);
return {quant, y_scale, y_zero_point};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_qlinearconcat : op_parser<parse_qlinearconcat>
{
std::vector<op_desc> operators() const { return {{"QLinearConcat"}}; }
// basic type checking for QLinearConcat Operator
void check_inputs(const std::vector<instruction_ref>& args) const
{
auto args_size = args.size();
// at least 5 input tensors:
// 1. is Y_scale: tensor(float)
// 2. is Y_zero_pont: tensor(uint8)/tensor(int8)
// remaining is a sequence of :
// 3. Tensor: tensor(uint8)/tensor(int8)
// 4. Scale: tensor(float),
// 5. ZeroPoint: tensor(uint8)/tensor(int8) tensors
// Size can be 5, 8, 11 ...
if((args_size < 5) or ((args_size - 2) % 3 != 0))
MIGRAPHX_THROW("QLINEARCONCAT: missing inputs");
auto y_zp = args[1];
auto y_zp_type = y_zp->get_shape().type();
if(y_zp_type != migraphx::shape::int8_type and y_zp_type != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARCONCAT: unsupported output type");
auto t0_type = args[2]->get_shape().type();
if(t0_type != migraphx::shape::int8_type and t0_type != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARCONCAT: unsupported input type");
for(auto idx = 2; idx < args.size(); idx += 3)
{
if((args[idx]->get_shape().type() != t0_type) or
(args[idx + 2]->get_shape().type() != t0_type))
{
MIGRAPHX_THROW("QLINEARCONCAT: mismatching input types");
}
}
}
instruction_ref parse(const op_desc& /* opd */,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(args);
if(not contains(info.attributes, "axis"))
MIGRAPHX_THROW("QLINEARCONCAT: missing axis attribute");
auto axis = parser.parse_value(info.attributes.at("axis")).template at<int64_t>();
std::vector<instruction_ref> tmp;
for(auto idx = 2; idx < args.size(); idx += 3)
{
auto data_tensor = args[idx];
auto scale = args[idx + 1];
auto zero_pt = args[idx + 2];
tmp.push_back(bcast_qdq_instr("dequantizelinear", data_tensor, scale, zero_pt, info));
}
auto y = info.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), tmp);
auto y_scale = args[0];
auto y_zero_pt = args[1];
return bcast_qdq_instr("quantizelinear", y, y_scale, y_zero_pt, info);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -25,7 +25,7 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
......@@ -45,7 +45,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
......@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
}
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
void quantize_8bits(program& prog,
const target& t,
shape::type_t precision,
const std::vector<parameter_map>& calibration,
const std::unordered_set<std::string>& ins_names)
{
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
// Run optimize_module() before converting to int8 to const eval and fold in FP32 to
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes(prog, {optimize_module{}});
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
std::vector<argument> args) {
float quantized_range = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift are needed only for the int8 type; we do not
// consider shift, so it is set to 0
......@@ -90,23 +83,22 @@ void quantize_int8(program& prog,
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
// if all values are 0, no need to do scaling
if(max_abs_vals->at(ins_index) == 0.0f)
if(float_equal(max_abs_vals->at(ins_index), 0.0f))
{
param_pair.first = 1.0f;
}
else
{
param_pair.first = 127.0f / max_abs_vals->at(ins_index);
param_pair.first = quantized_range / max_abs_vals->at(ins_index);
}
int8_quant_params->at(ins_index) = param_pair;
quant_8bit_params->at(ins_index) = param_pair;
};
// pass to add capture argument op
std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
......@@ -134,11 +126,11 @@ void quantize_int8(program& prog,
}
// print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
if(enabled(MIGRAPHX_8BITS_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
for(std::size_t i = 0; i < quant_8bit_params->size(); ++i)
{
auto param = int8_quant_params->at(i);
auto param = quant_8bit_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
......@@ -146,11 +138,44 @@ void quantize_int8(program& prog,
}
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
{quantize_8bits_pass{precision, *quant_8bit_params},
simplify_qdq{},
optimize_module{},
dead_code_elimination{}});
}
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::unordered_set<std::string>& ins_names)
{
std::unordered_set<std::string> op_names = {"convolution", "dot"};
if(op_names != ins_names)
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}
void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
std::unordered_set<std::string> supported_ins_names;
auto* mm = prog.get_main_module();
for(auto ins : iterator_for(*mm))
{
if(ins->name() == "convert")
{
continue;
}
if(not starts_with(ins->name(), "@"))
{
supported_ins_names.insert(ins->name());
}
}
quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx