Commit cac6c759 authored by Paul

Merge

parents 4bde67c4 a60bdb67
@@ -89,7 +89,7 @@ requests==2.28.2
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.30.0
+rocm-docs-core==0.30.1
     # via -r requirements.in
 smmap==5.0.0
     # via gitdb
...
@@ -82,7 +82,7 @@ Print debug statements for the ``schedule`` pass.
 Set to "1", "enable", "enabled", "yes", or "true" to use.
 Traces instructions replaced with a constant.
-.. envvar:: MIGRAPHX_INT8_QUANTIZATION_PARAMS
+.. envvar:: MIGRAPHX_8BITS_QUANTIZATION_PARAMS
 Set to "1", "enable", "enabled", "yes", or "true" to use.
 Print the quantization parameters in only the main module.
...
@@ -38,3 +38,6 @@ Quantize for fp16
 Quantize for int8
+.. option:: --fp8
+Quantize for Float8E4M3FNUZ type
@@ -55,6 +55,7 @@ See below for a comprehensive list of commands and option arguments, as well as
 | --exhaustive-tune | Enable exhaustive search to find fastest kernel |
 | --fp16 | Quantize for fp16 |
 | --int8 | Quantize for int8 |
+| --fp8 | Quantize for Float8E4M3FNUZ type |
 | --rms-tol | Tolerance for the RMS error (Default: 0.001) |
 | --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
 | --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
...
@@ -23,10 +23,9 @@
 #####################################################################################
 google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
 nlohmann/json@v3.8.0
-live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
 ROCmSoftwarePlatform/half@rocm-5.6.0
 pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
 msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
 sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
 ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On
+ROCmSoftwarePlatform/rocMLIR@08597bdd875eab888d2df826863828cdca5c8bb4 -DBUILD_FAT_LIBROCKCOMPILER=On
@@ -81,7 +81,7 @@ add_library(migraphx
     promote_literals.cpp
     quantization.cpp
     quantize_fp16.cpp
-    quantize_int8.cpp
+    quantize_8bits.cpp
     reduce_dims.cpp
     register_op.cpp
     register_target.cpp
...
@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
 struct quantize_int8_options
 {
     std::vector<parameter_map> calibration = {};
-    std::vector<std::string> op_names = {};
+    std::unordered_set<std::string> op_names = {};
 };
 void add_op_name(quantize_int8_options& options, const char* name)
 {
-    options.op_names.push_back(name);
+    options.op_names.insert(name);
 }
 void add_calibration_data(quantize_int8_options& options, parameter_map& data)
...
@@ -445,6 +445,7 @@ struct compiler
     compiler_target ct;
     compile_options co;
     bool to_fp16 = false;
+    bool to_fp8  = false;
     bool to_int8 = false;
     std::vector<std::string> fill0;
@@ -468,6 +469,7 @@ struct compiler
            ap.set_value(true));
         ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
         ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
+        ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8e4m3fnuz type"), ap.set_value(true));
     }
     auto params(const program& p)
@@ -518,6 +520,10 @@ struct compiler
         {
             quantize_int8(p, t, {host_params(p)});
         }
+        if(to_fp8)
+        {
+            quantize_fp8(p, t, {host_params(p)});
+        }
         p.compile(t, co);
         l.save(p);
         return p;
...
@@ -27,6 +27,17 @@
 #include <utility>
 #include <migraphx/config.hpp>
+// Similar to decltype(auto) except it will propagate any substitution failures
+// NOLINTNEXTLINE
+#define MIGRAPHX_RETURNS(...) \
+    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+// Lifts an expression into a function object so it can be passed to a higher-order function
+// NOLINTNEXTLINE
+#define MIGRAPHX_LIFT(...) \
+    [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
+        (__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
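To see what MIGRAPHX_LIFT buys, note that an overload set such as std::max cannot be passed directly to a higher-order function, while the lifted form can. A minimal standalone sketch (the two macro definitions are copied from the hunk above; the rest is illustrative, not MIGraphX code):

#include <algorithm>
#include <numeric>
#include <vector>

#define MIGRAPHX_RETURNS(...) \
    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
#define MIGRAPHX_LIFT(...) \
    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))

int main()
{
    std::vector<int> v = {3, 1, 4, 1, 5};
    // std::accumulate(..., std::max) does not compile: std::max names an
    // overload set, not a value. Lifting it defers overload resolution to
    // the call site inside accumulate:
    int m = std::accumulate(v.begin(), v.end(), 0, MIGRAPHX_LIFT(std::max));
    return m == 5 ? 0 : 1;
}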
@@ -32,8 +32,8 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-template <class T, class F>
-void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
+template <class T, class U, class F>
+void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
 {
     std::size_t n_dims = cmat.get_shape().lens().size();
     std::size_t dim_0  = n_dims - 2;
@@ -52,7 +52,8 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
         double s = 0.0;
         dfor(k)([&](auto kk) {
             a_idx[dim_1] = b_idx[dim_0] = kk;
-            s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
+            s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
+                 static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
         });
         cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
     });
...
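The static_cast<double> in the inner product is what lets the same template serve low-precision element types: products are formed and summed at double precision rather than in the (possibly 8-bit) element type. A toy illustration of the failure mode, using float as a stand-in for a narrow type (not MIGraphX code):

#include <cstdio>

int main()
{
    float  narrow_acc = 0.0f; // accumulate at element precision
    double wide_acc   = 0.0;  // accumulate at double, as gemm now does
    for(int i = 0; i < 40000000; i++)
    {
        narrow_acc += 1e-4f;
        wide_acc += static_cast<double>(1e-4f);
    }
    // float stalls near 2048 once the increment drops below half an ulp;
    // double ends near the true sum of ~4000
    std::printf("narrow = %f, wide = %f\n", narrow_acc, wide_acc);
}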
@@ -44,9 +44,11 @@ struct quant_dot
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t = a.type();
-        if(t != shape::int8_type)
+        std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type,
+                                                              shape::fp8e4m3fnuz_type};
+        if(not contains(suppported_types, t))
         {
-            MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
+            MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
         }
         if(not std::all_of(
@@ -73,6 +75,10 @@ struct quant_dot
         auto out_lens = a.lens();
         out_lens[dim_1] = b.lens()[dim_1];
+        if(t == shape::fp8e4m3fnuz_type)
+        {
+            return {shape::float_type, out_lens};
+        } // else int8 gemm
         return {shape::int32_type, out_lens};
     }
 };
...
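The two output types above follow from accumulation ranges: an int8*int8 product is at most 127*127 = 16129, so an int32 accumulator is exact for inner dimensions up to about 2^31/16129 ≈ 133,000, while fp8e4m3fnuz has no wider integer counterpart, so its result is produced as float. A standalone check of the int8 bound:

#include <cstdint>
#include <cstdio>

int main()
{
    const int32_t k = 1024; // inner (reduction) dimension
    int32_t acc     = 0;
    for(int32_t i = 0; i < k; ++i)
        acc += int32_t{127} * int32_t{127}; // worst-case int8*int8 product
    std::printf("worst-case int8 dot over k=%d: %d (well inside int32)\n", k, acc);
}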
@@ -112,84 +112,6 @@ struct reshape
         return {s0.type(), output_dyn_dims};
     }
-    template <class Iterator>
-    static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
-    {
-        std::size_t x = 1;
-        auto it = std::find_if(start, last, [&](auto i) {
-            x *= i;
-            return x >= dim;
-        });
-        if(x != dim)
-            return start;
-        return it;
-    }
-    // This will attempt to alias the dimensions of the input shape to the lens of
-    // `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
-    // can remove previous nullopts that were sent back for the alias case
-    static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
-    {
-        if(input.standard())
-            return shape{input.type(), rdims};
-        const auto& idims    = input.lens();
-        const auto& istrides = input.strides();
-        std::vector<std::size_t> rstrides;
-        std::size_t i = 0;
-        std::size_t r = 0;
-        while(i < idims.size() and r < rdims.size())
-        {
-            auto idim = idims[i];
-            auto rdim = rdims[r];
-            if(rdim == idim)
-            {
-                rstrides.push_back(istrides[i]);
-            }
-            // squeeze
-            else if(rdim > idim)
-            {
-                auto start = idims.begin() + i;
-                auto it    = compute_end_dim(start, idims.end(), rdim);
-                auto n     = it - start;
-                assert((i + n) <= istrides.size());
-                i += n;
-                rstrides.push_back(istrides[i]);
-            }
-            // unsqueeze
-            else // if(rdim < idim)
-            {
-                auto start = rdims.begin() + i;
-                auto it    = compute_end_dim(start, rdims.end(), idim);
-                auto n     = it - start;
-                assert((r + n) <= rdims.size());
-                auto stride = istrides[i] * idim;
-                std::for_each(start, it + 1, [&](auto dim) {
-                    stride /= dim;
-                    rstrides.push_back(stride);
-                });
-                r += n;
-            }
-            i++;
-            r++;
-        }
-        // Handle trailing 1s
-        if(rstrides.size() < rdims.size() and not rstrides.empty())
-        {
-            auto stride = rstrides.back();
-            for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
-            {
-                (void)d;
-                rstrides.push_back(stride);
-            }
-        }
-        return shape{input.type(), rdims, rstrides};
-    }
     shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
     {
         check_shapes{inputs, *this}.has(1);
@@ -219,14 +141,14 @@
             }
         }
-        auto s = reshape_dims(inputs.front(), rdims);
-        if(s->elements() != inputs.front().elements())
+        auto s = shape{inputs.front().type(), rdims};
+        if(s.elements() != inputs.front().elements())
             MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
-                           std::to_string(s->elements()) + " elements whereas the input has " +
+                           std::to_string(s.elements()) + " elements whereas the input has " +
                            std::to_string(inputs.front().elements()));
-        return *s;
+        return s;
     }
     shape compute_shape(std::vector<shape> inputs) const
...
@@ -110,22 +110,69 @@ struct reshape_lazy
         return it;
     }
+    template <class OptionalPair>
+    static OptionalPair try_merge_pairs(OptionalPair p2, OptionalPair p1)
+    {
+        if(not p1.has_value())
+            return nullopt;
+        if(not p2.has_value())
+            return nullopt;
+        auto dim1     = p1->first;
+        auto dim2     = p2->first;
+        auto stride1  = p1->second;
+        auto stride2  = p2->second;
+        auto elements = dim1 * dim2;
+        // Transposed
+        if(stride2 > stride1)
+            return nullopt;
+        // Broadcasted check to avoid division by zero
+        if(stride2 == 0)
+        {
+            if(stride1 == 0)
+                return {{elements, 0}};
+            return nullopt;
+        }
+        if(stride1 % stride2 != 0)
+            return nullopt;
+        auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
+        // Nonpacked
+        if(space != elements)
+            return nullopt;
+        return {{elements, stride2}};
+    }
+    template <class DimIterator, class StrideIterator>
+    static optional<std::size_t> merge_strides(DimIterator dim_start,
+                                               DimIterator dim_last,
+                                               StrideIterator stride_start,
+                                               StrideIterator stride_last)
+    {
+        if(dim_start == dim_last)
+            return nullopt;
+        (void)stride_start; // Is only used in the assert
+        assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
+        auto make_pair_optional = [&](auto dim, auto stride) {
+            return std::make_optional(std::make_pair(dim, stride));
+        };
+        auto dim_stride_pair =
+            std::inner_product(std::make_reverse_iterator(dim_last - 1),
+                               std::make_reverse_iterator(dim_start),
+                               std::make_reverse_iterator(stride_last - 1),
+                               make_pair_optional(*std::prev(dim_last), *std::prev(stride_last)),
+                               MIGRAPHX_LIFT(try_merge_pairs),
+                               make_pair_optional);
+        if(not dim_stride_pair.has_value())
+            return nullopt;
+        return dim_stride_pair->second;
+    }
     template <class DimIterator, class StrideIterator>
     static auto can_strides_merge(DimIterator dim_start,
                                   DimIterator dim_last,
                                   StrideIterator stride_start,
                                   StrideIterator stride_last)
     {
-        assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
-        auto cstride = *std::prev(stride_last);
-        return std::equal(std::make_reverse_iterator(dim_last),
-                          std::make_reverse_iterator(dim_start + 1),
-                          std::make_reverse_iterator(stride_last - 1),
-                          std::make_reverse_iterator(stride_start),
-                          [&](auto dim, auto stride) {
-                              cstride *= dim;
-                              return stride == cstride;
-                          });
+        return merge_strides(dim_start, dim_last, stride_start, stride_last).has_value();
     }
     // This will attempt to alias the dimensions of the input shape to the lens of
...
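The fold above merges adjacent (dim, stride) pairs only when they describe one packed, non-transposed block of memory. A simplified standalone sketch of that pairwise rule (std::optional stands in for migraphx::optional, and the fold over all dimensions is omitted):

#include <cstddef>
#include <cstdio>
#include <optional>
#include <utility>

using pair_t = std::optional<std::pair<std::size_t, std::size_t>>; // (dim, stride)

pair_t try_merge(pair_t outer, pair_t inner) // outer varies slower than inner
{
    if(not outer or not inner)
        return std::nullopt;
    auto d1 = outer->first, s1 = outer->second;
    auto d2 = inner->first, s2 = inner->second;
    if(s2 > s1)
        return std::nullopt; // transposed layout
    if(s2 == 0)              // broadcast: mergeable only with another broadcast
        return s1 == 0 ? pair_t{std::make_pair(d1 * d2, std::size_t{0})} : pair_t{std::nullopt};
    if(s1 % s2 != 0)
        return std::nullopt;
    // packed iff the two dims span exactly d1*d2 elements with no gap
    if((s1 * d1 + s2 * d2 - s1) / s2 != d1 * d2)
        return std::nullopt;
    return pair_t{std::make_pair(d1 * d2, s2)};
}

int main()
{
    auto packed = try_merge(pair_t{{4, 6}}, pair_t{{6, 1}}); // row-major {4,6}: merges to {24,1}
    auto padded = try_merge(pair_t{{4, 8}}, pair_t{{6, 1}}); // row stride 8 leaves gaps: no merge
    std::printf("packed: %d, padded: %d\n", int(packed.has_value()), int(padded.has_value()));
}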
@@ -44,8 +44,10 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
 MIGRAPHX_EXPORT void quantize_int8(program& prog,
                                    const target& t,
                                    const std::vector<parameter_map>& calibration,
-                                   const std::vector<std::string>& ins_names = {"dot",
-                                                                                "convolution"});
+                                   const std::unordered_set<std::string>& ins_names = {
+                                       "dot", "convolution"});
+MIGRAPHX_EXPORT void
+quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
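For orientation, a hedged usage sketch of the new entry point (parse_onnx, make_target, and program::compile come from the public MIGraphX headers; the model path is a placeholder, and the empty calibration vector means scales stay at their 64.0f default):

#include <vector>
#include <migraphx/onnx.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_target.hpp>

int main()
{
    auto prog = migraphx::parse_onnx("model.onnx"); // placeholder model file
    auto t    = migraphx::make_target("gpu");
    std::vector<migraphx::parameter_map> calibration; // empty: default scales
    migraphx::quantize_fp8(prog, t, calibration);
    prog.compile(t);
}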
 /*
 * The MIT License (MIT)
 *
-* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
@@ -21,10 +21,11 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
-#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
-#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
+#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include <functional>
 #include <migraphx/argument.hpp>
@@ -37,11 +38,11 @@ struct program;
 struct module;
 /**
- * capture inputs of operators to be quantized to int8
+ * capture inputs of operators to be quantized to int8 or fp8
 */
 struct MIGRAPHX_EXPORT capture_arguments_pass
 {
-    std::vector<std::string> ins_names = {"dot", "convolution"};
+    std::unordered_set<std::string> ins_names = {"dot", "convolution"};
     std::function<void(std::size_t, std::vector<argument>)> f{};
     std::size_t* param_index = nullptr;
     std::string name() const { return "capture_arguments"; }
@@ -49,13 +50,13 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
 };
 /**
- * quantize a program to int8
+ * quantize a program to int8 or fp8
 */
-struct MIGRAPHX_EXPORT quantize_int8_pass
+struct MIGRAPHX_EXPORT quantize_8bits_pass
 {
-    std::vector<std::string> ins_names = {"dot", "convolution"};
+    shape::type_t precision = shape::int8_type;
     std::vector<std::pair<float, float>> quant_params;
-    std::string name() const { return "quantize_int8"; }
+    std::string name() const { return "quantize_8bits"; }
     void apply(module& m) const;
 };
...
@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts)
             smod->finalize(contexts);
         }
     }
+#ifndef BUILD_DEV
+    if(std::any_of(this->begin(), this->end(), [](const auto i) {
+           return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
+       }))
+    {
+        std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
+                     "incorrect final outputs\n";
+    }
+#endif
     // Warn when an instruction is not normalized
     auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/*
*********************************************************************************
* Reference: see DynamicQuantizeLinear in *
* https://github.com/onnx/onnx/blob/main/docs/Operators.md *
*********************************************************************************
DynamicQuantizeLinear
A Function to fuse calculation for Scale, Zero Point and FP32->8Bit conversion of FP32 Input data.
Outputs Scale, ZeroPoint and Quantized Input for a given FP32 Input. Scale is calculated as:
y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
* where qmax and qmin are max and min values for quantization range i.e. [0, 255] in case of uint8
* data range is adjusted to include 0.
Zero point is calculated as:
intermediate_zero_point = qmin - min(x)/y_scale
y_zero_point = cast(round(saturate(intermediate_zero_point)))
* where qmax and qmin are max and min values for quantization range, i.e. [0, 255] in case of uint8
* for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now
only uint8 is supported.
* rounding to nearest ties to even. Data quantization formula is:
y = saturate (round (x / y_scale) + y_zero_point)
* for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now
only uint8 is supported.
* rounding to nearest ties to even.
Version
This version of the operator has been available since version 11 of the default ONNX operator set.
Inputs
x : T1
Input tensor
Outputs
y : T2
Quantized output tensor
y_scale : tensor(float)
Output scale. It's a scalar, which means a per-tensor/layer quantization.
y_zero_point : T2
Output zero point. It's a scalar, which means a per-tensor/layer quantization.
Type Constraints
T1 : tensor(float)
Constrain 'x' to float tensor.
T2 : tensor(uint8)
Constrain 'y_zero_point' and 'y' to 8-bit unsigned integer tensor.
*/
struct parse_dynamicquantizelinear : op_parser<parse_dynamicquantizelinear>
{
std::vector<op_desc> operators() const { return {{"DynamicQuantizeLinear"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
auto x = args[0];
auto x_shape = x->get_shape();
auto x_type = x_shape.type();
if(x_shape.dynamic())
MIGRAPHX_THROW("DYNAMICQUANTIZELINEAR: dynamic shapes are not supported");
auto x_reshaped =
(x_shape.lens().size() == 1)
? x
: info.add_instruction(
migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x);
auto lit_0 = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0}});
x_reshaped =
info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, lit_0);
// 1. Computing y_scale
// Note: currently, DynamicQuantizeLinear only has uint8 quantization:
const auto Q_MAX = std::numeric_limits<uint8_t>::max();
const auto Q_MIN = std::numeric_limits<uint8_t>::min();
auto q_range =
info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MAX - Q_MIN}});
// maximum(0, max(x))
auto max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
// minimum(0, min(x))
auto min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);
// y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
auto sub0 = info.add_common_op("sub", max_x, min_x);
auto y_scale = info.add_common_op("div", sub0, q_range);
// 2. Computing y_zero_point
// intermediate_zero_point = qmin - min(x) / y_scale
auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MIN}});
auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MAX}});
auto sub1 = info.add_common_op("sub", q_min, min_x);
auto interm_zp = info.add_common_op("div", sub1, y_scale);
// y_zero_point = cast(round(saturate(itermediate_zero_point)))
auto saturate = info.add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
auto round = info.add_instruction(migraphx::make_op("nearbyint"), saturate);
auto y_zero_point = info.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::uint8_type}}), round);
// 3. quantize x with y_scale and y_zero_point
auto quant = bcast_qdq_instr("quantizelinear", x, y_scale, y_zero_point, info);
return {quant, y_scale, y_zero_point};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
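To make the comment block's formulas concrete, a small worked example independent of the parser above: for x = {-1, 1, 2} quantized to uint8, y_scale = (2 - (-1))/255 = 3/255 and y_zero_point = round(0 - (-1)/y_scale) = 85, so the inputs map to {0, 170, 255}:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> x = {-1.0f, 1.0f, 2.0f};
    float mx = *std::max_element(x.begin(), x.end());
    float mn = *std::min_element(x.begin(), x.end());
    // y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
    float y_scale = (std::max(0.0f, mx) - std::min(0.0f, mn)) / 255.0f;
    // y_zero_point = cast(round(saturate(qmin - min(x) / y_scale)))
    float zp = std::nearbyint(std::clamp(0.0f - mn / y_scale, 0.0f, 255.0f)); // 85
    for(float v : x)
    {
        // y = saturate(round(x / y_scale) + y_zero_point)
        float q = std::clamp(std::nearbyint(v / y_scale) + zp, 0.0f, 255.0f);
        std::printf("%g -> %g\n", v, q); // -1 -> 0, 1 -> 170, 2 -> 255
    }
}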
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/broadcast_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_qlinearconcat : op_parser<parse_qlinearconcat>
{
std::vector<op_desc> operators() const { return {{"QLinearConcat"}}; }
// basic type checking for QLinearConcat Operator
void check_inputs(const std::vector<instruction_ref>& args) const
{
auto args_size = args.size();
        // at least 5 input tensors:
        // 1. Y_scale: tensor(float)
        // 2. Y_zero_point: tensor(uint8)/tensor(int8)
        // remaining is a sequence of:
        // 3. Tensor: tensor(uint8)/tensor(int8)
        // 4. Scale: tensor(float)
        // 5. ZeroPoint: tensor(uint8)/tensor(int8)
        // Size can be 5, 8, 11, ...
if((args_size < 5) or ((args_size - 2) % 3 != 0))
MIGRAPHX_THROW("QLINEARCONCAT: missing inputs");
auto y_zp = args[1];
auto y_zp_type = y_zp->get_shape().type();
if(y_zp_type != migraphx::shape::int8_type and y_zp_type != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARCONCAT: unsupported output type");
auto t0_type = args[2]->get_shape().type();
if(t0_type != migraphx::shape::int8_type and t0_type != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARCONCAT: unsupported input type");
for(auto idx = 2; idx < args.size(); idx += 3)
{
if((args[idx]->get_shape().type() != t0_type) or
(args[idx + 2]->get_shape().type() != t0_type))
{
MIGRAPHX_THROW("QLINEARCONCAT: mismatching input types");
}
}
}
instruction_ref parse(const op_desc& /* opd */,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(args);
if(not contains(info.attributes, "axis"))
MIGRAPHX_THROW("QLINEARCONCAT: missing axis attribute");
auto axis = parser.parse_value(info.attributes.at("axis")).template at<int64_t>();
std::vector<instruction_ref> tmp;
for(auto idx = 2; idx < args.size(); idx += 3)
{
auto data_tensor = args[idx];
auto scale = args[idx + 1];
auto zero_pt = args[idx + 2];
tmp.push_back(bcast_qdq_instr("dequantizelinear", data_tensor, scale, zero_pt, info));
}
auto y = info.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), tmp);
auto y_scale = args[0];
auto y_zero_pt = args[1];
return bcast_qdq_instr("quantizelinear", y, y_scale, y_zero_pt, info);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
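The parse() above builds the usual QDQ pattern: dequantize every input with its own (scale, zero_point), concatenate in real space, then requantize with the output pair. A numeric sketch of that arithmetic (all scales and zero points here are made-up values, not from the source):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<unsigned char> a = {0, 128}, b = {64, 255};
    float sa = 0.02f, sb = 0.01f, sy = 0.015f; // per-tensor input/output scales
    int zpa = 0, zpb = 128, zpy = 64;          // per-tensor zero points

    std::vector<float> real; // dequantize each input, then "concat"
    for(auto v : a)
        real.push_back((v - zpa) * sa);
    for(auto v : b)
        real.push_back((v - zpb) * sb);

    for(float r : real) // requantize into the shared output space
    {
        float q = std::clamp(std::nearbyint(r / sy) + zpy, 0.0f, 255.0f);
        std::printf("%g -> %g\n", r, q);
    }
}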
@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         py::arg("prog"),
         py::arg("t"),
         py::arg("calibration") = std::vector<migraphx::parameter_map>{},
-        py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
+        py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"});
 #ifdef HAVE_GPU
     m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
...
 /*
 * The MIT License (MIT)
 *
-* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
@@ -25,7 +25,7 @@
 #include <migraphx/instruction_ref.hpp>
 #include <migraphx/quantization.hpp>
 #include <migraphx/quantize_fp16.hpp>
-#include <migraphx/quantize_int8.hpp>
+#include <migraphx/quantize_8bits.hpp>
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/eliminate_common_subexpression.hpp>
@@ -45,7 +45,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
 // This function is to convert any instructions specified in the input
 // from double or float to float16 by inserting a convert operator.
@@ -57,29 +57,22 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
     run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
 }
-void quantize_int8(program& prog,
-                   const target& t,
-                   const std::vector<parameter_map>& calibration,
-                   const std::vector<std::string>& ins_names)
+void quantize_8bits(program& prog,
+                    const target& t,
+                    shape::type_t precision,
+                    const std::vector<parameter_map>& calibration,
+                    const std::unordered_set<std::string>& ins_names)
 {
-    std::set<std::string> op_names = {"convolution", "dot"};
-    std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
-    if(not std::includes(
-           op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
-    {
-        MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
-    }
-    // Run optimize_module() before converting to int8 to const eval and fold in FP32 to
+    // Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
     // avoid loss of precision.
     run_passes(prog, {optimize_module{}});
-    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
+    std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
         std::make_shared<std::vector<std::pair<float, float>>>();
     std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
-    auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
-                                                                   std::vector<argument> args) {
+    float quantized_range  = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
+    auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
         std::pair<float, float> param_pair{64.0f, 0.0f};
         // scale and shift is need for only int8 type, and we do not
         // consider shift, so set shift to 0
@@ -90,23 +83,22 @@ void quantize_8bits(program& prog,
         auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
         auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
         max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
         // if all values are 0, no need to do scaling
-        if(max_abs_vals->at(ins_index) == 0.0f)
+        if(float_equal(max_abs_vals->at(ins_index), 0.0f))
         {
             param_pair.first = 1.0f;
         }
         else
         {
-            param_pair.first = 127.0f / max_abs_vals->at(ins_index);
+            param_pair.first = quantized_range / max_abs_vals->at(ins_index);
         }
-        int8_quant_params->at(ins_index) = param_pair;
+        quant_8bit_params->at(ins_index) = param_pair;
     };
     // pass to add capture argument op
     std::size_t param_num = 0;
     run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
-    int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
+    quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
     max_abs_vals->resize(param_num, 0.0f);
     // use the calibration data to compute the quantization scale
@@ -134,11 +126,11 @@ void quantize_8bits(program& prog,
     }
     // print the quantization parameters in only the main module
-    if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
+    if(enabled(MIGRAPHX_8BITS_QUANTIZATION_PARAMS{}))
     {
-        for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
+        for(std::size_t i = 0; i < quant_8bit_params->size(); ++i)
         {
-            auto param = int8_quant_params->at(i);
+            auto param = quant_8bit_params->at(i);
             std::cout << "ins_index = " << i << ", scale = " << param.first
                       << ", shift = " << param.second << std::endl;
         }
@@ -146,11 +138,44 @@ void quantize_8bits(program& prog,
     }
     run_passes(prog,
-               {quantize_int8_pass{ins_names, *int8_quant_params},
+               {quantize_8bits_pass{precision, *quant_8bit_params},
                 simplify_qdq{},
                 optimize_module{},
                 dead_code_elimination{}});
 }
+void quantize_int8(program& prog,
+                   const target& t,
+                   const std::vector<parameter_map>& calibration,
+                   const std::unordered_set<std::string>& ins_names)
+{
+    std::unordered_set<std::string> op_names = {"convolution", "dot"};
+    if(op_names != ins_names)
+    {
+        MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
+    }
+    quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
+}
+void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
+{
+    std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
+                 "incorrect final outputs\n";
+    std::unordered_set<std::string> supported_ins_names;
+    auto* mm = prog.get_main_module();
+    for(auto ins : iterator_for(*mm))
+    {
+        if(ins->name() == "convert")
+        {
+            continue;
+        }
+        if(not starts_with(ins->name(), "@"))
+        {
+            supported_ins_names.insert(ins->name());
+        }
+    }
+    quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
+}
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
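One detail worth noting in quantize_8bits above: quantized_range is 127 for int8 but 240 for fp8, since 240 is the largest finite value of fp8e4m3fnuz. A minimal sketch of the resulting scale computation (mirroring calc_quant_params, simplified):

#include <cstdio>

// scale maps the observed |x| range onto the representable range of the target type
float quant_scale(float max_abs, bool fp8)
{
    const float quantized_range = fp8 ? 240.0f : 127.0f;
    return (max_abs == 0.0f) ? 1.0f : quantized_range / max_abs;
}

int main()
{
    std::printf("int8 scale for |x| <= 3.2: %g\n", quant_scale(3.2f, false)); // 39.6875
    std::printf("fp8  scale for |x| <= 3.2: %g\n", quant_scale(3.2f, true));  // 75
}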