Unverified Commit b73427c9 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into fix_for_multiconfig_generators

parents 55e635e5 4c059fa3
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
struct exception_list
{
std::vector<std::exception_ptr> exceptions;
std::mutex m;
void add_exception()
{
std::lock_guard<std::mutex> guard(m);
exceptions.push_back(std::current_exception());
}
template <class F>
auto collect(F f)
{
return [f, this](auto&&... xs) {
try
{
f(std::forward<decltype(xs)>(xs)...);
}
catch(...)
{
this->add_exception();
}
};
}
void throw_if_exception() const
{
if(not exceptions.empty())
std::rethrow_exception(exceptions.front());
}
};
} // namespace detail
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt par_transform(InputIt first1, InputIt last1, OutputIt d_first, UnaryOperation unary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(std::execution::par, first1, last1, d_first, std::move(unary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = unary_op(first1[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt par_transform(
InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, BinaryOperation binary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(
std::execution::par, first1, last1, first2, d_first, std::move(binary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = binary_op(first1[i], first2[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt, class UnaryFunction>
void par_for_each(InputIt first, InputIt last, UnaryFunction f)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail::exception_list ex;
std::for_each(std::execution::par, first, last, ex.collect(std::move(f)));
ex.throw_if_exception();
#else
simple_par_for(last - first, [&](auto i) { f(first[i]); });
#endif
}
template <class... Ts>
auto par_copy_if(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::copy_if(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::copy_if(std::forward<Ts>(xs)...);
#endif
}
template <class... Ts>
auto par_sort(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::sort(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::sort(std::forward<Ts>(xs)...);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
......@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
#include <migraphx/par.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f)
void par_for(std::size_t n, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
par_for_impl(n, threadsize, f);
using iterator = basic_iota_iterator<id, std::size_t>;
par_for_each(iterator{0, {}}, iterator{n, {}}, f);
}
template <class F>
void par_for(std::size_t n, F f)
void par_for(std::size_t n, std::size_t, F f)
{
const int min_grain = 8;
par_for(n, min_grain, f);
par_for(n, f);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -21,47 +21,99 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/gather.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis)
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
const auto& input_shape = arg1.get_shape();
auto lens = input_shape.lens();
auto axis_dim_size = lens[axis];
lens[axis] = arg2.get_shape().elements();
shape out_comp_shape{result.get_shape().type(), lens};
std::size_t nelements = result.get_shape().elements();
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
visit_all(result, arg1)([&](auto output, auto input_v) {
hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
arg2.visit([&](auto indices) {
const auto* indices_ptr = device_cast(indices.data());
auto* output_ptr = device_cast(output.data());
gs_launch(stream, nelements, 256)([=](auto i) __device__ {
auto idx = out_comp.multi(i);
auto in_index = indices_ptr[idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
idx[axis] = in_index;
output_ptr[i] = input[idx];
});
template <class F>
void simple_par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
});
assert(work >= n);
}
}
template <class F>
void simple_par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
simple_par_for_impl(n, threadsize, f);
}
return result;
template <class F>
void simple_par_for(std::size_t n, F f)
{
const int min_grain = 8;
simple_par_for(n, min_grain, f);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -24,21 +24,21 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR")
inline int tune_axis(int n_dim, int axis, const std::string& op_name = "OPERATOR")
{
if(axis >= n_dim or std::abs(axis) > n_dim)
{
if(axis < 0)
axis += n_dim;
if(axis < 0 or axis >= n_dim)
MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
}
return (axis < 0) ? axis + n_dim : axis;
return axis;
}
} // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,27 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
argument MIGRAPHX_DEVICE_EXPORT pad(hipStream_t stream,
argument result,
argument arg1,
float value,
std::vector<std::int64_t> pads);
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
} // namespace device
} // namespace gpu
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
This diff is collapsed.
......@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::TYPE_PROTOS:
case onnx::AttributeProto::TYPE_PROTO:
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
......@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(),
data_int32.end(),
std::back_inserter(data_fp8),
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}
case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64:
......@@ -609,6 +625,17 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
case 18: {
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
return shape::fp8e4m3fnuz_type;
}
case 14:
case 15:
case 16:
case 17:
case 19:
case 20:
default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
......
......@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size .
size_t batch_size = s0.lens().front();
auto rand_dummy = info.add_literal(
migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
auto rand_dummy = info.add_literal(migraphx::literal{
migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
randoms =
info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
}
......
......@@ -22,14 +22,8 @@
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -39,76 +33,14 @@ struct parse_pooling : op_parser<parse_pooling>
{
std::vector<op_desc> operators() const
{
return {{"AveragePool", "average"},
{"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"},
{"MaxPool", "max"},
{"LpPool", "lpnorm"},
{"GlobalLpPool", "lpnorm"}};
}
value handle_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values) const
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
return {
{"AveragePool", "average"},
{"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"},
{"MaxPool", "max"},
{"LpPool", "lpnorm"},
{"GlobalLpPool", "lpnorm"},
};
}
instruction_ref parse(const op_desc& opd,
......@@ -116,148 +48,8 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
}
return add_pooling_op(opd, std::move(info), args[0]);
};
};
} // namespace onnx
......
......@@ -23,6 +23,7 @@
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/make_op.hpp>
......@@ -36,90 +37,56 @@ namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearGlobalAveragePool in *
* Reference: see QLinearAveragePool and QLinearGlobalAveragePool in *
* github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
*/
QLinearGlobalAveragePool consumes an input tensor X and applies
Average pooling across the values in the same channel. This is
equivalent to AveragePool with kernel size equal to the spatial
dimension of input tensor. Input is of type uint8_t or int8_t.
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Attributes
channels_last : int
Inputs
X : T
Input data tensor from the previous operator; According to channels_last, dimensions for image case
are (N x C x H x W), or (N x H x W x C) where N is the batch size, C is the number of channels, and
H and W are the height and the width of the data. For non image case, the dimensions are in the form
of (N x C x D1 x D2 ... Dn), or (N x D1 X D2 ... Dn x C) where N is the batch size.
x_scale : tensor(float)
Scale of quantized input 'X'. It must be a scalar.
x_zero_point : T
Zero point tensor for input 'X'. It must be a scalar.
y_scale : tensor(float)
Scale of quantized output 'Y'. It must be a scalar.
y_zero_point : T
Zero point tensor for output 'Y'. It must be a scalar.
Outputs
Y : T
Output data tensor from pooling across the input tensor. The output tensor has the same rank as the
input. with the N and C value keep it value, while the other dimensions are all 1. Type Constraints
T : tensor(uint8), tensor(int8)
Constrain input and output types to signed/unsigned int8 tensors.
*/
struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool>
struct parse_qlinearpooling : op_parser<parse_qlinearpooling>
{
std::vector<op_desc> operators() const { return {{"QLinearGlobalAveragePool"}}; }
// basic type checking for QLinearGlobalAveragePool Operator
void check_inputs(const std::vector<instruction_ref>& args) const
std::vector<op_desc> operators() const
{
if(args.size() < 5)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: missing inputs");
return {{"QLinearGlobalAveragePool", "average"}, {"QLinearAveragePool", "average"}};
}
const auto& in_x = args[0];
const auto& zero_pt_x = args[2];
const auto& zero_pt_y = args[4];
void check_inputs(const op_desc& opd, const std::vector<instruction_ref>& args) const
{
const auto& in_x = args[0];
const auto onnx_name = opd.onnx_name;
if(in_x->get_shape().ndim() <= 2)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: input dimensions too small");
MIGRAPHX_THROW(onnx_name + ": input dimensions too small");
auto type_x = in_x->get_shape().type();
if(type_x != migraphx::shape::int8_type and type_x != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: unsupported input type");
MIGRAPHX_THROW(onnx_name + ": unsupported input type");
const auto& zero_pt_x = args[2];
if(type_x != zero_pt_x->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: input zero point");
if(type_x != zero_pt_y->get_shape().type())
MIGRAPHX_THROW("QLINEARGLOBALAVERAGEPOOL: mismatched type: output zero point");
MIGRAPHX_THROW(onnx_name + ": mismatched type: input zero point");
if(args.size() == 5)
{
const auto& zero_pt_y = args[4];
if(type_x != zero_pt_y->get_shape().type())
MIGRAPHX_THROW(onnx_name + ": mismatched type: output zero point");
}
}
instruction_ref parse(const op_desc& /* opd */,
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
int channels_last =
parser.parse_value(info.attributes.at("channels_last")).template at<int>();
if(channels_last != 0)
MIGRAPHX_THROW(
"QLINEARGLOBALAVERAGEPOOL: channels_last (N x D1..Dn x C) is not supported");
if(contains(info.attributes, "channel_last"))
{
int channels_last =
parser.parse_value(info.attributes.at("channels_last")).template at<int>();
if(channels_last != 0)
MIGRAPHX_THROW(opd.onnx_name + ": channels_last (N x D1..Dn x C) is not supported");
}
check_inputs(args);
check_inputs(opd, args);
// Input: X
......@@ -128,21 +95,18 @@ struct parse_qlinearglobalaveragepool : op_parser<parse_qlinearglobalaveragepool
const auto& zero_pt_x = args[2];
auto dquant_x = bcast_qdq_instr("dequantizelinear", in_x, scale_x, zero_pt_x, info);
// Output Y = globalaveragepool(X)
auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average};
auto lens = in_x->get_shape().lens();
std::vector<size_t> lengths(lens.begin() + 2, lens.end());
op.lengths = lengths;
op.padding = std::vector<size_t>(lens.size());
auto out_y = info.add_instruction(op, dquant_x);
// Output Y = pooling_op(X)
const auto& scale_y = args[3];
const auto& zero_pt_y = args[4];
auto out_y = add_pooling_op(opd, info, dquant_x);
auto out_quant_y = bcast_qdq_instr("quantizelinear", out_y, scale_y, zero_pt_y, info);
const auto& in_scale_y = args[3];
// zero_pt for Y is supplied as the last optional argument..
if(args.size() == 5)
return (bcast_qdq_instr("quantizelinear", out_y, in_scale_y, args[4], info));
return out_quant_y;
// if no zero_pt: just broadcast the scale..
auto bcast_scale_y = bcast_scalar_instr(out_y->get_shape(), in_scale_y, info);
return (info.add_instruction(migraphx::make_op("quantizelinear"), out_y, bcast_scale_y));
}
};
......
......@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd>
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
std::string reduction = "none";
if(contains(info.attributes, "reduction"))
{
if(info.attributes.at("reduction").s() == "add")
return info.add_instruction(migraphx::make_op("scatternd_add"), args);
if(info.attributes.at("reduction").s() == "mul")
return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
reduction = info.attributes.at("reduction").s();
if(not contains({"none", "add", "mul", "min", "max"}, reduction))
{
MIGRAPHX_THROW("PARSE_SCATTERND: unsupported reduction mode " + reduction);
}
}
return info.add_instruction(migraphx::make_op("scatternd_none"), args);
return info.add_instruction(migraphx::make_op("scatternd_" + reduction), args);
}
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// generate unique output stream y, given input stream x;
//
// case unsorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// output(s):
// y: [2, 1, 3, 4] --- the unique output
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
//
// case sorted:
// input x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// output(s):
// y: [1, 2, 3, 4] --- the unique output
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of indices of x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
struct parse_unique : op_parser<parse_unique>
{
std::vector<op_desc> operators() const { return {{"Unique"}}; }
std::vector<instruction_ref> parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int64_t sorted = 1; // default = sorted.
if(contains(info.attributes, "sorted"))
sorted = parser.parse_value(info.attributes.at("sorted")).at<int>();
std::optional<int64_t> axis;
if(contains(info.attributes, "axis"))
{
auto n_dim = args[0]->get_shape().ndim();
axis = parser.parse_value(info.attributes.at("axis")).at<int>();
axis = tune_axis(n_dim, *axis, opd.op_name);
}
migraphx::argument data_arg = args.back()->eval();
auto opr = axis ? migraphx::make_op("unique", {{"axis", *axis}, {"sorted", sorted}})
: migraphx::make_op("unique", {{"sorted", sorted}});
auto u_opr = info.add_instruction(opr, args.at(0));
auto i_y = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), u_opr);
auto i_y_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), u_opr);
auto i_x_idx = info.add_instruction(make_op("get_tuple_elem", {{"index", 2}}), u_opr);
auto i_count = info.add_instruction(make_op("get_tuple_elem", {{"index", 3}}), u_opr);
return {i_y, i_y_idx, i_x_idx, i_count};
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/pooling.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values)
{
auto kdims = in_shape.ndim() - 2;
if(starts_with(opd.onnx_name, "Global") or starts_with(opd.onnx_name, "QLinearGlobal"))
{
// if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
}
else
{
// works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
}
if(contains(info.attributes, "ceil_mode"))
{
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
}
if(contains(info.attributes, "strides"))
{
values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
}
if(contains(info.attributes, "kernel_shape"))
{
values["lengths"].clear();
copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
check_attr_sizes(kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads available only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
return values;
}
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0)
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_pooling_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
check_attr_sizes(
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
}
if(paddings.size() != 2 * kdims)
{
paddings.resize(kdims * 2);
std::fill_n(paddings.begin(), 2 * kdims, 0);
}
if(values["padding"].size() != kdims)
{
values["padding"].resize(kdims);
std::fill_n(values["padding"].begin(), kdims, 0);
}
if(values["stride"].size() != kdims)
{
values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1);
}
if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(not slice_start.empty())
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()});
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information
slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2,
out_lens.end(),
slice_start.begin(),
slice_end.begin(),
[](auto i, auto j) { return i + j; });
}
values["padding"] = std::vector<size_t>(paddings.begin(), paddings.end());
check_asym_padding(info, l0, paddings, values, count_include_pad, pad_val);
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}), l1);
}
return l1;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
int64_t max_quant = 0;
int64_t min_quant = 0;
double max_quant = 0;
double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max();
min_quant = qt.min();
......@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
std::vector<double> min_data(s.elements(), min_quant);
std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
......
......@@ -27,7 +27,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
......@@ -461,7 +461,7 @@ struct stream_info
std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first);
......
......@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
auto prev_ins = qop_arg;
std::vector<instruction_ref> ins_inbetween;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
// instructions
while(prev_ins != dqins)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
ins_inbetween.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto qinp = dqins->inputs().front();
for(auto ins : reverse_iterator_for(ins_inbetween))
{
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
}
return qinp;
}
......@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
// Only INT8 or FP8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::int8_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
......@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq;
instruction_ref out_scale;
instruction_ref zero_point;
......
......@@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t)
case st::int32_type: return dt::s32;
case st::int8_type: return dt::s8;
case st::uint8_type: return dt::u8;
case st::fp8e4m3fnuz_type: MIGRAPHX_THROW("fp8e4m3fnuz unsupported in DNNL");
default: MIGRAPHX_THROW("Unsupported data type");
}
}
......
......@@ -340,7 +340,6 @@ struct cpu_apply
{"reduce_min", "reduction_min"},
{"reduce_sum", "reduction_sum"},
});
extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution");
......@@ -376,6 +375,12 @@ struct cpu_apply
// Apply these operators first so the inputs can be const folded
for(auto it : iterator_for(*modl))
{
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pow")
{
apply_pow(it);
......@@ -383,6 +388,12 @@ struct cpu_apply
}
for(auto it : iterator_for(*modl))
{
// skip lowering if input has fp8 as one of the inputs since oneDNN doesn't have fp8
// supported yet.
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
continue;
if(it->name() == "pooling")
{
apply_pooling(it);
......
......@@ -126,7 +126,6 @@ add_library(migraphx_gpu
fuse_ck.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
hip.cpp
kernel.cpp
......@@ -140,7 +139,6 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
prefuse_ops.cpp
pad.cpp
perfdb.cpp
pooling.cpp
reverse.cpp
......@@ -168,12 +166,10 @@ endfunction()
register_migraphx_gpu_ops(hip_
argmax
argmin
gather
logsoftmax
loop
multinomial
nonzero
pad
prefix_scan_sum
reverse
scatter
......@@ -263,6 +259,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
......@@ -292,10 +290,18 @@ else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
add_subdirectory(driver)
......
......@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
......@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
// disable vectorization for fp8 types
if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) {
return ishape.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
return {1, axis};
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment