"sgl-kernel/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "2b1da821b5b8c1c887942efd2918c8c04ad0821a"
Commit f69d828d authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents fe36d210 24148857
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <array>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,6 +39,18 @@ namespace op { ...@@ -38,6 +39,18 @@ namespace op {
/** /**
* Slice operator that accepts variable axes, starts and ends. * Slice operator that accepts variable axes, starts and ends.
* All of `starts`, `ends`, and `axes` must be supplied by either
* their attribute or an input (but not both).
*
* Valid calls:
* slice(input); axes, starts, ends set
* slice(input, starts); axes, ends set
* slice(input, ends); starts, axes set
* slice(input, axes); starts, ends set
* slice(input, starts, ends); axes set
* slice(input, starts, axes); ends set
* slice(input, ends, axes); starts set
* slice(input, starts, ends, axes); none set
* *
* Attributes: * Attributes:
* axes: constant axes to slice over (optional) * axes: constant axes to slice over (optional)
...@@ -46,8 +59,8 @@ namespace op { ...@@ -46,8 +59,8 @@ namespace op {
* *
* Parameters: * Parameters:
* data: the input tensor to slice (dynamic or static shape) * data: the input tensor to slice (dynamic or static shape)
* input_starts: starting indicies of slice (optional, static shape) * input_starts: starting indices of slice (optional, static shape)
* input_ends: ending indicies of slice (optional, static shape) * input_ends: ending indices of slice (optional, static shape)
* input_axes: axes to slice over (optional, static shape) * input_axes: axes to slice over (optional, static shape)
*/ */
struct slice struct slice
...@@ -56,6 +69,18 @@ struct slice ...@@ -56,6 +69,18 @@ struct slice
std::vector<int64_t> starts{}; std::vector<int64_t> starts{};
std::vector<int64_t> ends{}; std::vector<int64_t> ends{};
/**
 * Named arrays for the set attribute possibilities.
 *
 * Each array is a {starts, ends, axes} mask, in the same element order that
 * get_set_attributes() produces. An element is `true` when the corresponding
 * attribute vector is non-empty and `false` when it is empty. The name of each
 * constant lists the attributes that are set; the shape and compute code
 * compares the result of get_set_attributes() against these constants to
 * decide which values come from attributes and which come from inputs.
 */
static constexpr std::array<bool, 3> all_set = {true, true, true};
static constexpr std::array<bool, 3> ends_axes = {false, true, true};
static constexpr std::array<bool, 3> starts_axes = {true, false, true};
static constexpr std::array<bool, 3> starts_ends = {true, true, false};
static constexpr std::array<bool, 3> axes_only = {false, false, true};
static constexpr std::array<bool, 3> ends_only = {false, true, false};
static constexpr std::array<bool, 3> starts_only = {true, false, false};
static constexpr std::array<bool, 3> none_set = {false, false, false};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -63,24 +88,26 @@ struct slice ...@@ -63,24 +88,26 @@ struct slice
} }
/** /**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are * Ensure that attribute axes is within limits.
* within limits. * Will attempt to normalize starts and ends; but will use the dynamic_dimension.max
* values for dynamic shapes. This makes it so you have to renormalize for
* non-fixed dynamic_dimensions.
*/ */
value attributes() const value attributes() const
{ {
value normalize = value::object{}; value normalize_axes = value::object{};
normalize["axes"] = value::array{normalize_attribute::include_min}; normalize_axes["axes"] = value::array{normalize_attribute::include_min};
normalize["starts"] = value::array{normalize_attribute::clip_max, normalize_axes["starts"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min, normalize_attribute::clip_min,
normalize_attribute::include_max, normalize_attribute::include_max,
normalize_attribute::use_len, normalize_attribute::use_len,
normalize_attribute::include_min}; normalize_attribute::include_min};
normalize["ends"] = value::array{normalize_attribute::clip_max, normalize_axes["ends"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min, normalize_attribute::clip_min,
normalize_attribute::include_max, normalize_attribute::include_max,
normalize_attribute::use_len, normalize_attribute::use_len,
normalize_attribute::include_min}; normalize_attribute::include_min};
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize_axes}};
} }
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
...@@ -88,7 +115,7 @@ struct slice ...@@ -88,7 +115,7 @@ struct slice
/** /**
* Computes the slice output shape dimensions for given starts, ends, and axes. * Computes the slice output shape dimensions for given starts, ends, and axes.
* Templated to also handle tensor views. * Templated to also handle tensor views.
* Possibily different type between [in_starts, in_ends] and [in_axes] if in_axes is this * Possibly different type between [in_starts, in_ends] and [in_axes] if in_axes is this
* object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid. * object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid.
*/ */
template <class A, class B> template <class A, class B>
...@@ -104,62 +131,160 @@ struct slice ...@@ -104,62 +131,160 @@ struct slice
return new_lens; return new_lens;
} }
shape normalize_compute_shape(std::vector<shape> inputs) const /// Get the attributes that are non-empty
std::array<bool, 3> get_set_attributes() const
{ {
check_shapes{inputs, *this, true}.has(1, 3, 4); std::array<std::vector<int64_t>, 3> attrs = {this->starts, this->ends, this->axes};
auto input_shape = inputs[0]; std::array<bool, 3> bool_vec;
if(inputs.size() == 1) std::transform(
attrs.cbegin(), attrs.cend(), bool_vec.begin(), [](auto a) { return not a.empty(); });
return bool_vec;
}
/// Helper function for normalize_compute_shape()
shape compute_two_or_more(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto set_attributes = get_set_attributes();
// check that inputs [1, end) are all 1D, have the same
// dimension, and are static
check_shapes{inputs.begin() + 1,
inputs.end(),
std::string("SLICE: inputs (starts, ends, and input_axes)"),
false}
.only_dims(1)
.same_dims();
auto dds = input_shape.to_dynamic().dyn_dims();
if(inputs.size() == 2)
{ {
auto t = input_shape.type(); if(set_attributes == ends_axes)
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{ {
MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis "); // attr ends and axes set; inputs are (data, input_starts)
if(inputs[1].lens().at(0) != axes.size())
{
MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch");
}
std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
dds.at(axis) = {0, dds.at(axis).max};
});
} }
if(input_shape.dynamic()) else if(set_attributes == starts_axes)
{ {
return shape{t, // attr starts and axes set; inputs are (data, input_ends)
lens_calc(input_shape.min_lens(), starts, ends, axes), if(inputs[1].lens().at(0) != axes.size())
lens_calc(input_shape.max_lens(), starts, ends, axes), {
{}}; MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch");
}
std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
dds.at(axis) = {0, dds.at(axis).max};
});
}
else if(set_attributes == starts_ends)
{
// attr starts and ends set; inputs are (data, input_axes)
if(inputs[1].lens().at(0) != starts.size())
{
MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch");
}
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max};
});
} }
else else
{ {
return shape{ MIGRAPHX_THROW("SLICE: Invalid 2 input and attributes configuration");
t, lens_calc(input_shape.lens(), starts, ends, axes), input_shape.strides()};
} }
} }
else else if(inputs.size() == 3)
{ {
// check that starts, ends, and optionally input_axes are all 1D, have the same if(set_attributes == axes_only)
// dimension, and are static
check_shapes{inputs.begin() + 1,
inputs.end(),
std::string("SLICE: inputs (starts, ends, and input_axes)"),
false}
.only_dims(1)
.same_dims();
auto dds = input_shape.to_dynamic().dyn_dims();
if(inputs.size() == 3)
{ {
// attr axes set; inputs are (data, input_starts, input_ends)
if(inputs[1].lens().at(0) != axes.size()) if(inputs[1].lens().at(0) != axes.size())
{ {
MIGRAPHX_THROW("SLICE: inputs starts and ends do not have the same dimension " MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch");
"as the axes attribute");
} }
std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
dds.at(axis) = {0, dds.at(axis).max}; dds.at(axis) = {0, dds.at(axis).max};
}); });
} }
else else if(set_attributes == ends_only)
{
// attr ends set; inputs are (data, input_starts, input_axes)
if(inputs[1].lens().at(0) != ends.size())
{
MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch");
}
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max};
});
}
else if(set_attributes == starts_only)
{ {
// if axes is an input, then all the output dimensions could be 0 to the max value // attr starts set; inputs are (data, input_ends, input_axes)
if(inputs[1].lens().at(0) != starts.size())
{
MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch");
}
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) { std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max}; return shape::dynamic_dimension{0, dd.max};
}); });
} }
return shape{input_shape.type(), dds}; else
{
MIGRAPHX_THROW("Invalid 3 input and attributes configuration");
}
}
else
{
// all 4 inputs (data, inputs_starts, input_ends, input_axes)
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max};
});
}
return shape{input_shape.type(), dds};
}
// uses the normalize_axes flag to normalize axes, starts, and ends
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1, 2, 3, 4);
if(inputs.size() == 1)
{
auto input_shape = inputs[0];
auto set_attributes = get_set_attributes();
if(set_attributes != all_set)
{
MIGRAPHX_THROW("SLICE 1_arg: Invalid 1 input and attributes configuration");
}
// NOTE: make sure to update how normalization works here if this type of slicing is
// changed to be allowed
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{
MIGRAPHX_THROW(
"SLICE 1_arg: slicing is not allowed on non-fixed dynamic input axis ");
}
if(input_shape.dynamic())
{
return shape{
input_shape.type(),
lens_calc(input_shape.min_lens(), this->starts, this->ends, this->axes),
lens_calc(input_shape.max_lens(), this->starts, this->ends, this->axes),
{}};
}
else
{
return shape{input_shape.type(),
lens_calc(input_shape.lens(), this->starts, this->ends, this->axes),
input_shape.strides()};
}
}
else
{
return compute_two_or_more(inputs);
} }
} }
...@@ -194,14 +319,14 @@ struct slice ...@@ -194,14 +319,14 @@ struct slice
/** /**
* Calculates the starting offset for the sliced tensor (for aliasing). * Calculates the starting offset for the sliced tensor (for aliasing).
* Used when the starts and/or the axes are inputs. * Used for 2-4 inputs to `slice`.
* *
* \param s static input shape * \param s static input shape
* \param input_starts starting indices of slice * \param input_starts starting indices of slice
* \param ax_vec axes to slice on * \param ax_vec axes to slice on
*/ */
template <class IndView, class Axes> template <class T>
auto compute_offset(const shape& s, const IndView& input_starts, const Axes& ax_vec) const auto compute_offset(const shape& s, const T& input_starts, const T& ax_vec) const
{ {
auto ret = 0; auto ret = 0;
for(std::size_t i = 0; i < ax_vec.size(); ++i) for(std::size_t i = 0; i < ax_vec.size(); ++i)
...@@ -212,106 +337,168 @@ struct slice ...@@ -212,106 +337,168 @@ struct slice
return ret * s.type_size(); return ret * s.type_size();
} }
/**
 * Two input version of normalize_inputs: normalizes runtime starts and ends
 * against this operator's `axes` attribute.
 *
 * \param input_shape shape of the tensor being sliced
 * \param input_starts starting indices supplied as an input
 * \param input_ends ending indices supplied as an input
 * \return map with keys "input_starts" and "input_ends" holding the normalized values
 */
std::unordered_map<std::string, std::vector<int64_t>>
normalize_inputs(const shape& input_shape,
                 const std::vector<int64_t>& input_starts,
                 const std::vector<int64_t>& input_ends) const
{
    // Same normalization rules that attributes() publishes under "normalize_axes".
    auto norm_rules = this->attributes().at("normalize_axes");

    std::unordered_map<std::string, std::vector<int64_t>> normalized;
    // starts are processed before ends, as in the returned map's declaration order
    normalized.emplace("input_starts",
                       normalize_indices(input_starts,
                                         this->axes,
                                         input_shape,
                                         norm_rules.at("starts"),
                                         "Slice variable input_starts"));
    normalized.emplace("input_ends",
                       normalize_indices(input_ends,
                                         this->axes,
                                         input_shape,
                                         norm_rules.at("ends"),
                                         "Slice variable input_ends"));
    return normalized;
}
/** /**
* Three input version of the normalize_inputs. * If given, normalize the inputs. Otherwise get from operator attributes.
* This one also checks that the input_axes are valid. * Return the values in a map.
*
* Parameters
* input_shape: static shape of the input
* input_starts: optional
* input_ends: optional
* input_axes: optional
*/ */
std::unordered_map<std::string, std::vector<int64_t>> std::unordered_map<std::string, std::vector<int64_t>>
normalize_inputs(shape input_shape, normalize_starts_ends_axes(shape input_shape,
const std::vector<int64_t>& input_starts, const optional<std::vector<int64_t>>& input_starts,
const std::vector<int64_t>& input_ends, const optional<std::vector<int64_t>>& input_ends,
const std::vector<int64_t>& input_axes) const const optional<std::vector<int64_t>>& input_axes) const
{ {
auto attrs = this->attributes().at("normalize_axes"); auto axes_attrs = this->attributes().at("normalize_axes");
auto norm_axes = std::vector<int64_t> norm_starts;
normalize_axes(input_axes, input_shape, attrs.at("axes"), "Slice variable input_axes"); std::vector<int64_t> norm_ends;
return {{"input_starts", std::vector<int64_t> norm_axes;
normalize_indices(input_starts, if(input_axes)
norm_axes, {
input_shape, norm_axes = normalize_axes(input_axes.value(),
attrs.at("starts"), input_shape,
"Slice variable input_starts")}, axes_attrs.at("axes"),
{"input_ends", "Slice variable input_axes");
normalize_indices(input_ends, }
norm_axes, else
input_shape, {
attrs.at("ends"), norm_axes = this->axes;
"Slice variable input ends")}, }
{"input_axes", norm_axes}}; if(input_starts)
{
norm_starts = normalize_indices(input_starts.value(),
norm_axes,
input_shape,
axes_attrs.at("starts"),
"Slice variable input_starts");
}
else
{
norm_starts = this->starts;
}
if(input_ends)
{
norm_ends = normalize_indices(input_ends.value(),
norm_axes,
input_shape,
axes_attrs.at("ends"),
"Slice variable input ends");
}
else
{
norm_ends = this->ends;
}
return {{"norm_starts", norm_starts}, {"norm_ends", norm_ends}, {"norm_axes", norm_axes}};
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto input_shape = input.get_shape(); auto input_shape = input.get_shape();
switch(args.size()) if(args.size() == 1)
{ {
case 1: {
std::size_t offset = compute_offset(input_shape); std::size_t offset = compute_offset(input_shape);
return {dyn_out.computed_shape, [=] { return input.data() + offset; }}; return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
} }
case 3: { else
shape calc_shape; {
std::size_t offset = 0; // Note that we re-normalize both the attributes and inputs because of the non-fixed
visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) { // dynamic input shape case. It's possible to only re-normalize if slicing over
auto norm_inputs = normalize_inputs(input_shape, // non-fixed dynamic_dimensions.
input_starts.template to_vector<int64_t>(), auto set_attributes = get_set_attributes();
input_ends.template to_vector<int64_t>()); std::unordered_map<std::string, std::vector<int64_t>> norm_inputs;
offset = compute_offset(input_shape, norm_inputs.at("input_starts"), this->axes); if(set_attributes == ends_axes)
calc_shape = {input_shape.type(), {
lens_calc(input_shape.lens(), // attr ends and axes set; inputs are (data, input_starts)
norm_inputs.at("input_starts"), args[1].visit([&](auto input_starts) {
norm_inputs.at("input_ends"), norm_inputs =
this->axes), normalize_starts_ends_axes(input_shape,
input_shape.strides()}; input_starts.template to_vector<int64_t>(),
}); this->ends,
return {calc_shape, [=] { return input.data() + offset; }}; this->axes);
} });
case 4: { }
shape calc_shape; else if(set_attributes == starts_axes)
std::size_t offset = 0; {
visit_all(args[1], args[2], args[3])( // attr starts and axes set; inputs are (data, input_ends)
[&](auto input_starts, auto input_ends, auto input_axes) { args[1].visit([&](auto input_ends) {
auto norm_inputs = normalize_inputs(input_shape, norm_inputs =
input_starts.template to_vector<int64_t>(), normalize_starts_ends_axes(input_shape,
input_ends.template to_vector<int64_t>(), this->starts,
input_axes.template to_vector<int64_t>()); input_ends.template to_vector<int64_t>(),
offset = compute_offset( this->axes);
input_shape, norm_inputs.at("input_starts"), norm_inputs.at("input_axes")); });
calc_shape = shape{input_shape.type(), }
lens_calc(input_shape.lens(), else if(set_attributes == starts_ends)
norm_inputs.at("input_starts"), {
norm_inputs.at("input_ends"), // attr starts and ends set; inputs are (data, input_axes)
norm_inputs.at("input_axes")), args[1].visit([&](auto input_axes) {
input_shape.strides()}; norm_inputs =
normalize_starts_ends_axes(input_shape,
this->starts,
this->ends,
input_axes.template to_vector<int64_t>());
}); });
}
else if(set_attributes == axes_only)
{
// attr axes set; inputs are (data, input_starts, input_ends)
visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
input_starts.template to_vector<int64_t>(),
input_ends.template to_vector<int64_t>(),
this->axes);
});
}
else if(set_attributes == ends_only)
{
// attr ends set; inputs are (data, input_starts, input_axes)
visit_all(args[1], args[2])([&](auto input_starts, auto input_axes) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
input_starts.template to_vector<int64_t>(),
this->ends,
input_axes.template to_vector<int64_t>());
});
}
else if(set_attributes == starts_only)
{
// attr starts set; inputs are (data, input_ends, input_axes)
visit_all(args[1], args[2])([&](auto input_ends, auto input_axes) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
this->starts,
input_ends.template to_vector<int64_t>(),
input_axes.template to_vector<int64_t>());
});
}
else
{
// no attr set, all inputs
visit_all(args[1], args[2], args[3])(
[&](auto input_starts, auto input_ends, auto input_axes) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
input_starts.template to_vector<int64_t>(),
input_ends.template to_vector<int64_t>(),
input_axes.template to_vector<int64_t>());
});
}
auto offset = compute_offset(
input_shape, norm_inputs.at("norm_starts"), norm_inputs.at("norm_axes"));
shape calc_shape = shape{input_shape.type(),
lens_calc(input_shape.lens(),
norm_inputs.at("norm_starts"),
norm_inputs.at("norm_ends"),
norm_inputs.at("norm_axes")),
input_shape.strides()};
return {calc_shape, [=] { return input.data() + offset; }}; return {calc_shape, [=] { return input.data() + offset; }};
} }
default: {
// Should never get here; covering in case some code change occurs
MIGRAPHX_THROW("SLICE: invalid number of inputs");
}
}
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -84,6 +84,7 @@ ...@@ -84,6 +84,7 @@
#include <migraphx/op/mod.hpp> #include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/nearbyint.hpp>
#include <migraphx/op/neg.hpp> #include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp> #include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp> #include <migraphx/op/nonzero.hpp>
...@@ -110,7 +111,6 @@ ...@@ -110,7 +111,6 @@
#include <migraphx/op/rnn_variable_seq_lens.hpp> #include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp> #include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp> #include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter_add.hpp> #include <migraphx/op/scatter_add.hpp>
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape ...@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \ m(int32_type, int32_t) \
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
...@@ -28,25 +28,35 @@ ...@@ -28,25 +28,35 @@
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Declares a class template named `trait` that, for any type X, inherits from
// the standard library trait of the same name (std::trait<X>). The resulting
// primary template therefore answers exactly like the std trait, while
// remaining a distinct template that can be explicitly specialized for
// additional types.
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \ template <> \
struct trait<T> : std::true_type \ struct trait<T> : std::true_type \
{ \ { \
}; };
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T> template <class T>
using accumulator_type = using accumulator_type =
std::conditional_t<is_floating_point<T>{}, std::conditional_t<is_floating_point<T>{},
......
...@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
return not input_shape.dyn_dims().at(ax).is_fixed();
}))
{
return vec;
}
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i); return input_shape.dyn_dims().at(i).max;
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
}); });
} }
else else
......
...@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED) ...@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto) protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto)
add_library(onnx-proto STATIC ${PROTO_SRCS}) add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR}) target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w) if(MSVC)
target_compile_options(onnx-proto PRIVATE /w)
else()
target_compile_options(onnx-proto PRIVATE -w)
endif()
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
...@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) ...@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header(migraphx_onnx) migraphx_generate_export_header(migraphx_onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx) rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL") target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
if(NOT WIN32)
target_link_libraries(migraphx_onnx PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_onnx PUBLIC migraphx) target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets( rocm_install_targets(
......
...@@ -97,10 +97,11 @@ struct onnx_parser ...@@ -97,10 +97,11 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
......
...@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
} }
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations; parser.max_loop_iterations = options.max_loop_iterations;
parser.limit_max_iterations = options.limit_max_iterations;
parser.use_dyn_output = options.use_dyn_output; parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error) if(options.print_program_on_error)
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"}, {"Neg", "neg"},
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "nearbyint"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/**
 * Parser for the ONNX `IsInf` node.
 *
 * Reads the optional integer attributes `detect_negative` and `detect_positive`
 * (each treated as enabled when absent) and builds:
 *  - both disabled: a single `false` literal multibroadcast to the input lens
 *  - both enabled:  a plain `isinf` of the input
 *  - one enabled:   `isinf` combined via `logical_and` with a sign test against zero
 */
struct parse_isinf : op_parser<parse_isinf>
{
    std::vector<op_desc> operators() const { return {{"IsInf", "isinf"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        // Fetch an optional integer attribute as a boolean; a missing attribute means "on".
        auto flag_or_true = [&](const std::string& attr_name) {
            if(not contains(info.attributes, attr_name))
            {
                return true;
            }
            return static_cast<bool>(parser.parse_value(info.attributes.at(attr_name)).at<int>());
        };
        const bool want_negative = flag_or_true("detect_negative");
        const bool want_positive = flag_or_true("detect_positive");

        auto input_shape = args[0]->get_shape();

        if(not(want_negative or want_positive))
        {
            // Neither sign requested: the answer is `false` everywhere.
            auto false_lit =
                info.add_literal(migraphx::literal{migraphx::shape{shape::bool_type}, {false}});
            return info.add_instruction(
                make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), false_lit);
        }

        auto inf_mask = info.add_instruction(make_op("isinf"), args[0]);
        if(want_negative and want_positive)
        {
            return inf_mask;
        }

        // Exactly one sign requested: restrict the isinf result with a comparison to zero.
        auto zero_lit =
            info.add_literal(migraphx::literal{migraphx::shape{input_shape.type()}, {0}});
        auto zero_bcast = info.add_instruction(
            make_op("multibroadcast", {{"out_lens", input_shape.lens()}}), zero_lit);

        std::string sign_cmp = "greater";
        if(want_negative)
        {
            sign_cmp = "less";
        }
        auto sign_mask = info.add_broadcastable_binary_op(sign_cmp, args[0], zero_bcast);

        // logical_and is fed a bool operand, so convert the comparison result when needed
        const bool needs_bool_cast = sign_mask->get_shape().type() != shape::bool_type;
        if(needs_bool_cast)
        {
            sign_mask = info.add_instruction(
                make_op("convert", {{"target_type", shape::bool_type}}), sign_mask);
        }
        return info.add_instruction(make_op("logical_and"), inf_mask, sign_mask);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop> ...@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
} }
} }
// cap max_iter because loop uses static shapes with max_iter size and huge numbers
// here can cause overflow
if(max_iterations > parser.limit_max_iterations)
{
std::cerr << "WARNING: PARSE_LOOP max_iterations exceeds the maximum loop "
"iterations limit, it will be changed from "
<< max_iterations << " to " << parser.limit_max_iterations << ".\n";
max_iterations = parser.limit_max_iterations;
}
// condition input is empty // condition input is empty
if(args.at(1)->name() == "undefined") if(args.at(1)->name() == "undefined")
{ {
......
...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv ...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
} }
} }
/**
 * Transposes the LSTM inputs that need it, in place, using permutation {1, 0, 2}.
 *
 * args[0] is always transposed; args[5] and args[6] are transposed only when
 * they are present and not undefined. Called by parse_lstm when the `layout`
 * attribute is non-zero.
 */
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
    std::vector<int64_t> swap_01{1, 0, 2};
    auto transposed = [&](instruction_ref ins) {
        return info.add_instruction(make_op("transpose", {{"permutation", swap_01}}), ins);
    };

    args[0] = transposed(args[0]);

    // Optional inputs at positions 5 and 6, handled in that order.
    for(std::size_t idx : {std::size_t{5}, std::size_t{6}})
    {
        if(args.size() > idx and not args[idx]->is_undefined())
        {
            args[idx] = transposed(args[idx]);
        }
    }
}
/**
 * Transposes the three LSTM outputs in place.
 *
 * hidden_states is permuted with {2, 0, 1, 3}; last_output and
 * last_cell_output are both permuted with {1, 0, 2}. Called by parse_lstm
 * when the `layout` attribute is non-zero.
 */
void lstm_transpose_outputs(onnx_parser::node_info& info,
                            instruction_ref& hidden_states,
                            instruction_ref& last_output,
                            instruction_ref& last_cell_output)
{
    auto permute = [&](instruction_ref ins, const std::vector<int64_t>& order) {
        return info.add_instruction(make_op("transpose", {{"permutation", order}}), ins);
    };

    const std::vector<int64_t> hidden_order{2, 0, 1, 3};
    const std::vector<int64_t> last_order{1, 0, 2};

    hidden_states    = permute(hidden_states, hidden_order);
    last_output      = permute(last_output, last_order);
    last_cell_output = permute(last_cell_output, last_order);
}
struct parse_lstm : op_parser<parse_lstm> struct parse_lstm : op_parser<parse_lstm>
{ {
std::vector<op_desc> operators() const { return {{"LSTM"}}; } std::vector<op_desc> operators() const { return {{"LSTM"}}; }
...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>(); input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
} }
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined operator to make 8 arguments // append undefined operator to make 8 arguments
if(args.size() < 8) if(args.size() < 8)
{ {
...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins); args.insert(args.end(), 8 - args.size(), ins);
} }
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm", auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output = auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -41,6 +41,9 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -41,6 +41,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
if(args.empty())
MIGRAPHX_THROW("PARSE_MULTINOMIAL: no arguments given");
int dtype = 6; int dtype = 6;
if(contains(info.attributes, "dtype")) if(contains(info.attributes, "dtype"))
dtype = info.attributes.at("dtype").i(); dtype = info.attributes.at("dtype").i();
...@@ -49,35 +52,90 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -49,35 +52,90 @@ struct parse_multinomial : op_parser<parse_multinomial>
size_t sample_size = 1; size_t sample_size = 1;
if(contains(info.attributes, "sample_size")) if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i(); sample_size = info.attributes.at("sample_size").i();
else
MIGRAPHX_THROW("PARSE_MULTINOMIAL: sample_size not given");
// Use logarithmic math to scale probabilities while avoiding division by very
// small numbers. Scaling by the maximum makes very tiny ranges more
// tractable; any constant factor gives an equivalent distribution, since the
// Multinomial op normalizes at runtime.
// Subtract the per-batch maximum log-probability, making the per-batch max 0 // Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes = auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]); info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
auto mb_maxes = info.add_instruction( auto cdf = info.add_common_op("sub", args[0], maxes);
migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
maxes);
auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes);
// Take the element-wise exponent to get probabilities in the range (0, 1] // Take the element-wise exponent to get probabilities in the range (0, 1]
cdf = info.add_instruction(migraphx::make_op("exp"), cdf); cdf = info.add_instruction(migraphx::make_op("exp"), cdf);
// Compute the cumulative density function // Compute the cumulative distribution function
cdf = info.add_instruction( cdf = info.add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// Pre-compute random distribution instruction_ref seed_input;
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed")) if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f()); {
float seed = info.attributes.at("seed").f();
migraphx::shape s{migraphx::shape::float_type, {1}};
std::vector<float> data = {seed};
seed_input = info.add_literal(migraphx::literal(s, data));
}
else
{
seed_input = info.add_instruction(migraphx::make_op("random_seed"));
}
instruction_ref randoms;
shape s0 = args[0]->get_shape();
if(s0.dynamic())
{
// Dynamic batch_size will be taken from args[0]. The input argument to this should
// have a second dimension of sample_size.
std::vector<shape::dynamic_dimension> dyn_dim_set;
dyn_dim_set.emplace_back(s0.dyn_dims().front());
dyn_dim_set.emplace_back(shape::dynamic_dimension{sample_size, sample_size});
// read the input dimensions
auto dim_of =
info.add_instruction(migraphx::make_op("dimensions_of", {{"end", 2}}), args[0]);
// The next two operations insert the value sample_size into the second array position
// make an argument of (1, 0)
shape s(shape::int64_type, {2});
std::vector<int64_t> data1{1, 0};
auto l1 = info.add_literal(s, data1);
auto batch_arg = info.add_instruction(migraphx::make_op("mul"), dim_of, l1);
std::vector<int64_t> data2(2, 0);
// make an argument of (0, sample_size)
data2[1] = sample_size;
auto l2 = info.add_literal(s, data2);
auto alloc_shape = info.add_instruction(migraphx::make_op("add"), batch_arg, l2);
// alloc_shape should contain the input-based shape dimensions as its values at runtime,
// and its own shape is {2}
// compile_shape is the shape used when compiling the Allocate op, and may be dynamic
migraphx::shape compile_shape =
migraphx::shape(s0.type(), {s0.dyn_dims().front(), {sample_size, sample_size}});
std::uniform_real_distribution<> dis(0.0, 1.0); // Allocate on-device storage for the random values
size_t batch_size = args[0]->get_shape().lens().front(); auto alloc = info.add_instruction(
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::make_op("allocate", {{"shape", to_value(compile_shape)}}), alloc_shape);
randoms = info.add_instruction(migraphx::make_op("random_uniform"), seed_input, alloc);
}
else
{
// use literal. The array populated by random_uniform may have any shape, as long as its
// number of elements is batch_size * sample_size .
size_t batch_size = s0.lens().front();
auto rand_dummy = info.add_literal(
migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
std::vector<float> random_dist(batch_size * sample_size); randoms =
std::generate(random_dist.begin(), random_dist.end(), [&]() { return dis(gen); }); info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
auto dist_lit = info.add_literal(migraphx::literal{dist_shape, random_dist}); }
return info.add_instruction( return info.add_instruction(
migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, dist_lit); migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, randoms);
} }
}; };
......
...@@ -36,7 +36,7 @@ namespace onnx { ...@@ -36,7 +36,7 @@ namespace onnx {
/* /*
********************************************************************************* *********************************************************************************
* Reference: see QLinearAdd in * * Reference: see QLinearAdd, QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md * * https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
********************************************************************************* *********************************************************************************
...@@ -49,6 +49,17 @@ namespace onnx { ...@@ -49,6 +49,17 @@ namespace onnx {
This version of the operator has been available since version 1 of the 'com.microsoft' operator This version of the operator has been available since version 1 of the 'com.microsoft' operator
set. set.
com.microsoft.QLinearMul
Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
support).
C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
General definition of binary QLinear* ops:
Inputs (7 - 8) Inputs (7 - 8)
A : T A : T
First operand. First operand.
...@@ -88,15 +99,18 @@ namespace onnx { ...@@ -88,15 +99,18 @@ namespace onnx {
*/ */
struct parse_qlinearadd : op_parser<parse_qlinearadd> struct parse_qlinearbinary : op_parser<parse_qlinearbinary>
{ {
std::vector<op_desc> operators() const { return {{"QLinearAdd"}}; } std::vector<op_desc> operators() const
{
return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}};
}
// basic type checking for QLinearAdd Operator // basic type checking for binary QLinear Operator
void check_inputs(const std::vector<instruction_ref>& args) const void check_inputs(const std::vector<instruction_ref>& args, const std::string& op_name) const
{ {
if(args.size() < 7) if(args.size() < 7)
MIGRAPHX_THROW("QLINEARADD: missing inputs"); MIGRAPHX_THROW(op_name + ": missing inputs");
const auto& in_a = args[0]; const auto& in_a = args[0];
const auto& in_b = args[3]; const auto& in_b = args[3];
...@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd> ...@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
auto type_a = sh_a.type(); auto type_a = sh_a.type();
auto type_b = sh_b.type(); auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type) if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type"); MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type) if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type"); MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_a != type_b) if(type_a != type_b)
MIGRAPHX_THROW("QLINEARADD: mismatched input types"); MIGRAPHX_THROW(op_name + ": mismatched input types");
} }
instruction_ref parse(const op_desc& /* opd */, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
check_inputs(args); check_inputs(args, opd.op_name);
// A // A
const auto& in_a = args[0]; const auto& in_a = args[0];
...@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd> ...@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
const auto& in_zero_pt_b = args[5]; const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info); auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);
// C = A + B // C = op(A, B)
auto out_c = info.add_common_op("add", dquant_a, dquant_b); auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b);
const auto& in_scale_c = args[6]; const auto& in_scale_c = args[6];
......
...@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr) ...@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return nearest_mode; return nearest_mode;
} }
/**
 * Returns the values of the `scales` attribute as doubles, or an empty vector
 * when the attribute is not present.
 */
static std::vector<double> get_scales(const onnx_parser::attribute_map& attr)
{
    if(not contains(attr, "scales"))
    {
        return {};
    }

    std::vector<double> result;
    copy(attr.at("scales").floats(), std::back_inserter(result));
    return result;
}
static void parse_args(const std::vector<instruction_ref>& args,
const std::vector<size_t>& in_lens,
const std::string& op_name,
std::vector<double>& vec_scale,
std::vector<std::size_t>& out_lens)
{
for(const auto& arg : args)
{
if(arg->name() == "undefined" or arg == args.front())
{
continue;
}
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + op_name +
": specified output size does not match input size");
}
// compute the scale
vec_scale.resize(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
// scale input
if(lens[0] == in_lens.size())
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + op_name + ": dynamic input scale is not supported!");
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
}
}
}
}
struct parse_resize : op_parser<parse_resize> struct parse_resize : op_parser<parse_resize>
{ {
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; } std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
...@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize> ...@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
std::vector<std::size_t> out_lens(in_lens.size()); std::vector<std::size_t> out_lens(in_lens.size());
// scale // scale
std::vector<double> vec_scale; std::vector<double> vec_scale = get_scales(info.attributes);
for(const auto& arg : args) // If `scales` was not an attribute, it must be an input
if(vec_scale.empty())
{ {
if(arg->name() == "undefined" or arg == args.front()) // Depending on the args, it *must* populate the `vec_scale`, and might populate
{ // `out_lens`
continue; parse_args(args, in_lens, opd.op_name, vec_scale, out_lens);
} }
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
}
// compute the scale if(in_lens.size() != vec_scale.size())
vec_scale.resize(in_lens.size()); {
std::transform(in_lens.begin(), MIGRAPHX_THROW("PARSE_" + opd.op_name + ": ranks of input and scale are different!");
in_lens.end(), }
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
// scale input // if the output was not calculated yet, we update it based on the scales
if(lens[0] == in_lens.size()) if(all_of(out_lens.cbegin(), out_lens.cend(), [](auto o) { return o == 0; }))
{ {
auto arg_scale = arg->eval(); std::transform(
check_arg_empty(arg_scale, in_lens.begin(),
"PARSE_" + opd.op_name + in_lens.end(),
": dynamic input scale is not supported!"); vec_scale.begin(),
out_lens.begin(),
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); }); [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
}
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
return static_cast<std::size_t>(idx * scale);
});
}
}
} }
shape out_s{in_s.type(), out_lens}; shape out_s{in_s.type(), out_lens};
...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize> ...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension // reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())}; std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
args[0] = info.make_contiguous(args[0]);
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]); auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
if(mode == "nearest") if(mode == "nearest")
......
...@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice> ...@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); } void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std::vector<int64_t> insert(instruction_ref arg) std::vector<int64_t> insert(instruction_ref arg)
{ {
std::vector<int64_t> result; std::vector<int64_t> result;
...@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice> ...@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes; sd.op.axes = axes;
} }
if(not sd.steps.empty()) if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{ {
if(sd.op.starts.empty() or sd.op.ends.empty()) if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported"); MIGRAPHX_THROW(
"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty()) if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported"); MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
} }
assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op // If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size())) for(auto i : range(sd.steps.size()))
{ {
......
...@@ -68,13 +68,34 @@ struct parse_split : op_parser<parse_split> ...@@ -68,13 +68,34 @@ struct parse_split : op_parser<parse_split>
// no split attribute, input is equally divided // no split attribute, input is equally divided
else else
{ {
if((lens[tuned_axis] % info.num_outputs) != 0) std::size_t num_outputs = info.num_outputs;
// the num_outputs attribute seems to be redundant since we already have
// node_info::num_outputs, but we can still perform an error check
if(contains(info.attributes, "num_outputs"))
{ {
MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " + num_outputs =
std::to_string(info.num_outputs) + " splits!"); parser.parse_value(info.attributes.at("num_outputs")).at<std::size_t>();
if(num_outputs != info.num_outputs)
{
MIGRAPHX_THROW("PARSE_SPLIT: num_outputs attribute " +
std::to_string(num_outputs) +
" doesn't match actual number of outputs " +
std::to_string(info.num_outputs) + "!");
}
}
if(lens[tuned_axis] % num_outputs == 0)
{
std::size_t chunk_size = lens[tuned_axis] / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
}
else
{
std::size_t chunk_size = lens[tuned_axis] / num_outputs + 1;
std::size_t last_chunk_size = lens[tuned_axis] - chunk_size * (num_outputs - 1);
vec_splits.resize(num_outputs - 1, chunk_size);
vec_splits.push_back(last_chunk_size);
} }
auto dl = lens[tuned_axis] / info.num_outputs;
vec_splits.resize(info.num_outputs, dl);
} }
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) != if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
......
...@@ -40,7 +40,7 @@ ...@@ -40,7 +40,7 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/float8.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#endif #endif
...@@ -144,6 +144,18 @@ struct npy_format_descriptor<half> ...@@ -144,6 +144,18 @@ struct npy_format_descriptor<half>
static constexpr auto name() { return _("half"); } static constexpr auto name() { return _("half"); }
}; };
/**
 * Format descriptor specialization for migraphx::fp8::fp8e4m3fnuz.
 *
 * `name()` reports the type as "fp8e4m3fnuz". `format()` returns the format
 * string "z"; format characters follow
 * https://docs.python.org/3/library/struct.html#format-characters
 */
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
    static constexpr auto name() { return _("fp8e4m3fnuz"); }

    static std::string format()
    {
        // TODO: need to figure out correct encoding; the value below is provisional
        return std::string{"z"};
    }
};
} // namespace detail } // namespace detail
} // namespace pybind11 } // namespace pybind11
...@@ -472,7 +484,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -472,7 +484,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
map_dyn_input_dims, map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error,
int64_t max_loop_iterations) { int64_t max_loop_iterations,
int64_t limit_max_iterations) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value; options.default_dyn_dim_value = default_dyn_dim_value;
...@@ -481,6 +494,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -481,6 +494,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations; options.max_loop_iterations = max_loop_iterations;
options.limit_max_iterations = limit_max_iterations;
return migraphx::parse_onnx(filename, options); return migraphx::parse_onnx(filename, options);
}, },
"Parse onnx file", "Parse onnx file",
...@@ -492,7 +506,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -492,7 +506,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(), std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10); py::arg("max_loop_iterations") = 10,
py::arg("limit_max_iterations") = std::numeric_limits<uint16_t>::max());
m.def( m.def(
"parse_onnx_buffer", "parse_onnx_buffer",
......
...@@ -56,7 +56,11 @@ target make_target(const std::string& name) ...@@ -56,7 +56,11 @@ target make_target(const std::string& name)
{ {
if(not contains(target_map(), name)) if(not contains(target_map(), name))
{ {
#ifdef _WIN32
std::string target_name = "migraphx_" + name + ".dll";
#else
std::string target_name = "libmigraphx_" + name + ".so"; std::string target_name = "libmigraphx_" + name + ".so";
#endif
store_target_lib(dynamic_loader(target_name)); store_target_lib(dynamic_loader(target_name));
} }
const auto it = target_map().find(name); const auto it = target_map().find(name);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment