"benchmark/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "1fb76ebb93894447d1e3a0a04bca4875563b328c"
Unverified Commit 52585d4f authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents f0370072 d8011adf
...@@ -40,6 +40,8 @@ namespace op { ...@@ -40,6 +40,8 @@ namespace op {
* 2. use_rank (default) vs use_len: * 2. use_rank (default) vs use_len:
* `use_rank` sets the max value/index of the attribute as the rank of lens. * `use_rank` sets the max value/index of the attribute as the rank of lens.
* `use_lens` sets the max value/index as the corresponding value in lens at the axes index. * `use_lens` sets the max value/index as the corresponding value in lens at the axes index.
* Uses the dynamic_dimension.max value for dynamic shapes. Returns the original vector
* (no normalization) if any of dynamic_dimension[axes] are not fixed.
* 3. `clip_min` vs. `not_clip_min` (default): * 3. `clip_min` vs. `not_clip_min` (default):
* Clip values less than the minimum to the minimum or not. * Clip values less than the minimum to the minimum or not.
* 4. `include_min` vs. `exclude_min` (default): * 4. `include_min` vs. `exclude_min` (default):
......
...@@ -30,11 +30,11 @@ ...@@ -30,11 +30,11 @@
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <fenv.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct quantizelinear struct quantizelinear
{ {
std::string name() const { return "quantizelinear"; } std::string name() const { return "quantizelinear"; }
...@@ -71,26 +71,26 @@ struct quantizelinear ...@@ -71,26 +71,26 @@ struct quantizelinear
{ {
y_zero_point = args.at(2); y_zero_point = args.at(2);
} }
argument result{output_shape}; argument result{output_shape};
auto rounding_mode = fegetround();
fesetround(FE_TONEAREST);
visit_all(result, y_zero_point)([&](auto output, auto zero_pts) { visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
visit_all(x, y_scale)([&](auto input, auto scales) { visit_all(x, y_scale)([&](auto input, auto scales) {
using quant_type = typename decltype(output)::value_type; using quant_type = typename decltype(output)::value_type;
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) + int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]); static_cast<int64_t>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value), output[i] = std::max(static_cast<int64_t>(min_value),
std::min(static_cast<int64_t>(max_value), quantized)); std::min(static_cast<int64_t>(max_value), quantized));
}); });
}); });
}); });
fesetround(rounding_mode);
return result; return result;
} }
}; };
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -38,6 +38,18 @@ namespace op { ...@@ -38,6 +38,18 @@ namespace op {
/** /**
* Slice operator that accepts variable axes, starts and ends. * Slice operator that accepts variable axes, starts and ends.
* All of `starts`, `ends`, and `axes` must be supplied by either
* their attribute or an input (but not both).
*
* Valid calls:
* slice(input); axes, starts, ends set
* slice(input, starts); axes, ends set
* slice(input, ends); starts, axes set
* slice(input, axes); starts, ends set
* slice(input, starts, ends); axes set
* slice(input, starts, axes); ends set
* slice(input, ends, axes); starts set
* slice(input, start, ends, axes); none set
* *
* Attributes: * Attributes:
* axes: constant axes to slice over (optional) * axes: constant axes to slice over (optional)
...@@ -46,8 +58,8 @@ namespace op { ...@@ -46,8 +58,8 @@ namespace op {
* *
* Parameters: * Parameters:
* data: the input tensor to slice (dynamic or static shape) * data: the input tensor to slice (dynamic or static shape)
* input_starts: starting indicies of slice (optional, static shape) * input_starts: starting indices of slice (optional, static shape)
* input_ends: ending indicies of slice (optional, static shape) * input_ends: ending indices of slice (optional, static shape)
* input_axes: axes to slice over (optional, static shape) * input_axes: axes to slice over (optional, static shape)
*/ */
struct slice struct slice
...@@ -56,6 +68,18 @@ struct slice ...@@ -56,6 +68,18 @@ struct slice
std::vector<int64_t> starts{}; std::vector<int64_t> starts{};
std::vector<int64_t> ends{}; std::vector<int64_t> ends{};
/**
* Named arrays for the set attribute possibilities.
*/
static constexpr std::array<bool, 3> all_set = {true, true, true};
static constexpr std::array<bool, 3> ends_axes = {false, true, true};
static constexpr std::array<bool, 3> starts_axes = {true, false, true};
static constexpr std::array<bool, 3> starts_ends = {true, true, false};
static constexpr std::array<bool, 3> axes_only = {false, false, true};
static constexpr std::array<bool, 3> ends_only = {false, true, false};
static constexpr std::array<bool, 3> starts_only = {true, false, false};
static constexpr std::array<bool, 3> none_set = {false, false, false};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -63,24 +87,26 @@ struct slice ...@@ -63,24 +87,26 @@ struct slice
} }
/** /**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are * Ensure that attribute axes is within limits.
* within limits. * Will attempt to normalize starts and ends; but will use the dynamic_dimension.max
* values for dynamic shapes. This makes it so you have to renormalize for
* non-fixed dynamic_dimensions.
*/ */
value attributes() const value attributes() const
{ {
value normalize = value::object{}; value normalize_axes = value::object{};
normalize["axes"] = value::array{normalize_attribute::include_min}; normalize_axes["axes"] = value::array{normalize_attribute::include_min};
normalize["starts"] = value::array{normalize_attribute::clip_max, normalize_axes["starts"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min, normalize_attribute::clip_min,
normalize_attribute::include_max, normalize_attribute::include_max,
normalize_attribute::use_len, normalize_attribute::use_len,
normalize_attribute::include_min}; normalize_attribute::include_min};
normalize["ends"] = value::array{normalize_attribute::clip_max, normalize_axes["ends"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min, normalize_attribute::clip_min,
normalize_attribute::include_max, normalize_attribute::include_max,
normalize_attribute::use_len, normalize_attribute::use_len,
normalize_attribute::include_min}; normalize_attribute::include_min};
return {{"normalize_axes", normalize}}; return {{"normalize_axes", normalize_axes}};
} }
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
...@@ -88,7 +114,7 @@ struct slice ...@@ -88,7 +114,7 @@ struct slice
/** /**
* Computes the slice output shape dimensions for given starts, ends,and axes. * Computes the slice output shape dimensions for given starts, ends,and axes.
* Templated to also handle tensor views. * Templated to also handle tensor views.
* Possibily different type between [in_starts, in_ends] and [in_axes] if in_axes is this * Possibly different type between [in_starts, in_ends] and [in_axes] if in_axes is this
* object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid. * object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid.
*/ */
template <class A, class B> template <class A, class B>
...@@ -104,62 +130,160 @@ struct slice ...@@ -104,62 +130,160 @@ struct slice
return new_lens; return new_lens;
} }
shape normalize_compute_shape(std::vector<shape> inputs) const /// Get the attributes that are non-empty
std::array<bool, 3> get_set_attributes() const
{ {
check_shapes{inputs, *this, true}.has(1, 3, 4); std::array<std::vector<int64_t>, 3> attrs = {this->starts, this->ends, this->axes};
auto input_shape = inputs[0]; std::array<bool, 3> bool_vec;
if(inputs.size() == 1) std::transform(
attrs.cbegin(), attrs.cend(), bool_vec.begin(), [](auto a) { return not a.empty(); });
return bool_vec;
}
/// Helper function for normalize_compute_shape()
shape compute_two_or_more(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto set_attributes = get_set_attributes();
// check that inputs [1, end) are all 1D, have the same
// dimension, and are static
check_shapes{inputs.begin() + 1,
inputs.end(),
std::string("SLICE: inputs (starts, ends, and input_axes)"),
false}
.only_dims(1)
.same_dims();
auto dds = input_shape.to_dynamic().dyn_dims();
if(inputs.size() == 2)
{ {
auto t = input_shape.type(); if(set_attributes == ends_axes)
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{ {
MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis "); // attr ends and axes set; inputs are (data, input_starts)
if(inputs[1].lens().at(0) != axes.size())
{
MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch");
}
std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
dds.at(axis) = {0, dds.at(axis).max};
});
} }
if(input_shape.dynamic()) else if(set_attributes == starts_axes)
{ {
return shape{t, // attr starts and axes set; inputs are (data, input_ends)
lens_calc(input_shape.min_lens(), starts, ends, axes), if(inputs[1].lens().at(0) != axes.size())
lens_calc(input_shape.max_lens(), starts, ends, axes), {
{}}; MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch");
}
std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
dds.at(axis) = {0, dds.at(axis).max};
});
}
else if(set_attributes == starts_ends)
{
// attr starts and ends set; inputs are (data, input_axes)
if(inputs[1].lens().at(0) != starts.size())
{
MIGRAPHX_THROW("SLICE: 2 input and attributes mismatch");
}
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max};
});
} }
else else
{ {
return shape{ MIGRAPHX_THROW("SLICE: Invalid 2 input and attributes configuration");
t, lens_calc(input_shape.lens(), starts, ends, axes), input_shape.strides()};
} }
} }
else else if(inputs.size() == 3)
{ {
// check that starts, ends, and optionally input_axes are all 1D, have the same if(set_attributes == axes_only)
// dimension, and are static
check_shapes{inputs.begin() + 1,
inputs.end(),
std::string("SLICE: inputs (starts, ends, and input_axes)"),
false}
.only_dims(1)
.same_dims();
auto dds = input_shape.to_dynamic().dyn_dims();
if(inputs.size() == 3)
{ {
// attr axes set; inputs are (data, input_starts, input_ends)
if(inputs[1].lens().at(0) != axes.size()) if(inputs[1].lens().at(0) != axes.size())
{ {
MIGRAPHX_THROW("SLICE: inputs starts and ends do not have the same dimension " MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch");
"as the axes attribute");
} }
std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
dds.at(axis) = {0, dds.at(axis).max}; dds.at(axis) = {0, dds.at(axis).max};
}); });
} }
else else if(set_attributes == ends_only)
{
// attr ends set; inputs are (data, input_starts, input_axes)
if(inputs[1].lens().at(0) != ends.size())
{
MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch");
}
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max};
});
}
else if(set_attributes == starts_only)
{ {
// if axes is an input, then all the output dimensions could be 0 to the max value // attr starts set; inputs are (data, input_ends, input_axes)
if(inputs[1].lens().at(0) != starts.size())
{
MIGRAPHX_THROW("SLICE: 3 input and attributes mismatch");
}
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) { std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max}; return shape::dynamic_dimension{0, dd.max};
}); });
} }
return shape{input_shape.type(), dds}; else
{
MIGRAPHX_THROW("Invalid 3 input and attributes configuration");
}
}
else
{
// all 4 inputs (data, inputs_starts, input_ends, input_axes)
std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
return shape::dynamic_dimension{0, dd.max};
});
}
return shape{input_shape.type(), dds};
}
// uses the normalize_axes flag to normalize axes, starts, and ends
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1, 2, 3, 4);
if(inputs.size() == 1)
{
auto input_shape = inputs[0];
auto set_attributes = get_set_attributes();
if(set_attributes != all_set)
{
MIGRAPHX_THROW("SLICE 1_arg: Invalid 1 input and attributes configuration");
}
// NOTE: make sure to update how normalization works here if this type of slicing is
// changed to be allowed
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{
MIGRAPHX_THROW(
"SLICE 1_arg: slicing is not allowed on non-fixed dynamic input axis ");
}
if(input_shape.dynamic())
{
return shape{
input_shape.type(),
lens_calc(input_shape.min_lens(), this->starts, this->ends, this->axes),
lens_calc(input_shape.max_lens(), this->starts, this->ends, this->axes),
{}};
}
else
{
return shape{input_shape.type(),
lens_calc(input_shape.lens(), this->starts, this->ends, this->axes),
input_shape.strides()};
}
}
else
{
return compute_two_or_more(inputs);
} }
} }
...@@ -194,14 +318,14 @@ struct slice ...@@ -194,14 +318,14 @@ struct slice
/** /**
* Calculates the starting offset for the sliced tensor (for aliasing). * Calculates the starting offset for the sliced tensor (for aliasing).
* Used when the starts and/or the axes are inputs. * Used for 2-4 inputs to `slice.
* *
* \param s static input shape * \param s static input shape
* \param input_starts starting indices of slice * \param input_starts starting indices of slice
* \param ax_vec axes to slice on * \param ax_vec axes to slice on
*/ */
template <class IndView, class Axes> template <class T>
auto compute_offset(const shape& s, const IndView& input_starts, const Axes& ax_vec) const auto compute_offset(const shape& s, const T& input_starts, const T& ax_vec) const
{ {
auto ret = 0; auto ret = 0;
for(std::size_t i = 0; i < ax_vec.size(); ++i) for(std::size_t i = 0; i < ax_vec.size(); ++i)
...@@ -212,106 +336,168 @@ struct slice ...@@ -212,106 +336,168 @@ struct slice
return ret * s.type_size(); return ret * s.type_size();
} }
std::unordered_map<std::string, std::vector<int64_t>>
normalize_inputs(const shape& input_shape,
const std::vector<int64_t>& input_starts,
const std::vector<int64_t>& input_ends) const
{
auto attrs = this->attributes().at("normalize_axes");
return {{"input_starts",
normalize_indices(input_starts,
this->axes,
input_shape,
attrs.at("starts"),
"Slice variable input_starts")},
{"input_ends",
normalize_indices(input_ends,
this->axes,
input_shape,
attrs.at("ends"),
"Slice variable input_ends")}};
}
/** /**
* Three input version of the normalize_inputs. * If given, normalize the inputs. Otherwise get from operator attributes.
* This one also checks that the input_axes are valid. * Return the values in a map.
*
* Parameters
* input_shape: static shape of the input
* input_starts: optional
* input_ends: optional
* input_ends: optional
*/ */
std::unordered_map<std::string, std::vector<int64_t>> std::unordered_map<std::string, std::vector<int64_t>>
normalize_inputs(shape input_shape, normalize_starts_ends_axes(shape input_shape,
const std::vector<int64_t>& input_starts, const optional<std::vector<int64_t>>& input_starts,
const std::vector<int64_t>& input_ends, const optional<std::vector<int64_t>>& input_ends,
const std::vector<int64_t>& input_axes) const const optional<std::vector<int64_t>>& input_axes) const
{ {
auto attrs = this->attributes().at("normalize_axes"); auto axes_attrs = this->attributes().at("normalize_axes");
auto norm_axes = std::vector<int64_t> norm_starts;
normalize_axes(input_axes, input_shape, attrs.at("axes"), "Slice variable input_axes"); std::vector<int64_t> norm_ends;
return {{"input_starts", std::vector<int64_t> norm_axes;
normalize_indices(input_starts, if(input_axes)
norm_axes, {
input_shape, norm_axes = normalize_axes(input_axes.value(),
attrs.at("starts"), input_shape,
"Slice variable input_starts")}, axes_attrs.at("axes"),
{"input_ends", "Slice variable input_axes");
normalize_indices(input_ends, }
norm_axes, else
input_shape, {
attrs.at("ends"), norm_axes = this->axes;
"Slice variable input ends")}, }
{"input_axes", norm_axes}}; if(input_starts)
{
norm_starts = normalize_indices(input_starts.value(),
norm_axes,
input_shape,
axes_attrs.at("starts"),
"Slice variable input_starts");
}
else
{
norm_starts = this->starts;
}
if(input_ends)
{
norm_ends = normalize_indices(input_ends.value(),
norm_axes,
input_shape,
axes_attrs.at("ends"),
"Slice variable input ends");
}
else
{
norm_ends = this->ends;
}
return {{"norm_starts", norm_starts}, {"norm_ends", norm_ends}, {"norm_axes", norm_axes}};
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto input_shape = input.get_shape(); auto input_shape = input.get_shape();
switch(args.size()) if(args.size() == 1)
{ {
case 1: {
std::size_t offset = compute_offset(input_shape); std::size_t offset = compute_offset(input_shape);
return {dyn_out.computed_shape, [=] { return input.data() + offset; }}; return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
} }
case 3: { else
shape calc_shape; {
std::size_t offset = 0; // Note that we re-normalize both the attributes and inputs because of the non-fixed
visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) { // dynamic input shape case. It's possible to only re-normalize if slicing over
auto norm_inputs = normalize_inputs(input_shape, // non-fixed dynamic_dimensions.
input_starts.template to_vector<int64_t>(), auto set_attributes = get_set_attributes();
input_ends.template to_vector<int64_t>()); std::unordered_map<std::string, std::vector<int64_t>> norm_inputs;
offset = compute_offset(input_shape, norm_inputs.at("input_starts"), this->axes); if(set_attributes == ends_axes)
calc_shape = {input_shape.type(), {
lens_calc(input_shape.lens(), // attr ends and axes set; inputs are (data, input_starts)
norm_inputs.at("input_starts"), args[1].visit([&](auto input_starts) {
norm_inputs.at("input_ends"), norm_inputs =
this->axes), normalize_starts_ends_axes(input_shape,
input_shape.strides()}; input_starts.template to_vector<int64_t>(),
}); this->ends,
return {calc_shape, [=] { return input.data() + offset; }}; this->axes);
} });
case 4: { }
shape calc_shape; else if(set_attributes == starts_axes)
std::size_t offset = 0; {
visit_all(args[1], args[2], args[3])( // attr starts and axes set; inputs are (data, input_ends)
[&](auto input_starts, auto input_ends, auto input_axes) { args[1].visit([&](auto input_ends) {
auto norm_inputs = normalize_inputs(input_shape, norm_inputs =
input_starts.template to_vector<int64_t>(), normalize_starts_ends_axes(input_shape,
input_ends.template to_vector<int64_t>(), this->starts,
input_axes.template to_vector<int64_t>()); input_ends.template to_vector<int64_t>(),
offset = compute_offset( this->axes);
input_shape, norm_inputs.at("input_starts"), norm_inputs.at("input_axes")); });
calc_shape = shape{input_shape.type(), }
lens_calc(input_shape.lens(), else if(set_attributes == starts_ends)
norm_inputs.at("input_starts"), {
norm_inputs.at("input_ends"), // attr starts and ends set; inputs are (data, input_axes)
norm_inputs.at("input_axes")), args[1].visit([&](auto input_axes) {
input_shape.strides()}; norm_inputs =
normalize_starts_ends_axes(input_shape,
this->starts,
this->ends,
input_axes.template to_vector<int64_t>());
}); });
}
else if(set_attributes == axes_only)
{
// attr axes set; inputs are (data, input_starts, input_ends)
visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
input_starts.template to_vector<int64_t>(),
input_ends.template to_vector<int64_t>(),
this->axes);
});
}
else if(set_attributes == ends_only)
{
// attr ends set; inputs are (data, input_starts, input_axes)
visit_all(args[1], args[2])([&](auto input_starts, auto input_axes) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
input_starts.template to_vector<int64_t>(),
this->ends,
input_axes.template to_vector<int64_t>());
});
}
else if(set_attributes == starts_only)
{
// attr starts set; inputs are (data, input_ends, input_axes)
visit_all(args[1], args[2])([&](auto input_ends, auto input_axes) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
this->starts,
input_ends.template to_vector<int64_t>(),
input_axes.template to_vector<int64_t>());
});
}
else
{
// no attr set, all inputs
visit_all(args[1], args[2], args[3])(
[&](auto input_starts, auto input_ends, auto input_axes) {
norm_inputs =
normalize_starts_ends_axes(input_shape,
input_starts.template to_vector<int64_t>(),
input_ends.template to_vector<int64_t>(),
input_axes.template to_vector<int64_t>());
});
}
auto offset = compute_offset(
input_shape, norm_inputs.at("norm_starts"), norm_inputs.at("norm_axes"));
shape calc_shape = shape{input_shape.type(),
lens_calc(input_shape.lens(),
norm_inputs.at("norm_starts"),
norm_inputs.at("norm_ends"),
norm_inputs.at("norm_axes")),
input_shape.strides()};
return {calc_shape, [=] { return input.data() + offset; }}; return {calc_shape, [=] { return input.data() + offset; }};
} }
default: {
// Should never get here; covering in case some code change occurs
MIGRAPHX_THROW("SLICE: invalid number of inputs");
}
}
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -84,6 +84,7 @@ ...@@ -84,6 +84,7 @@
#include <migraphx/op/mod.hpp> #include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/nearbyint.hpp>
#include <migraphx/op/neg.hpp> #include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp> #include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp> #include <migraphx/op/nonzero.hpp>
...@@ -110,7 +111,6 @@ ...@@ -110,7 +111,6 @@
#include <migraphx/op/rnn_variable_seq_lens.hpp> #include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp> #include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp> #include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter_add.hpp> #include <migraphx/op/scatter_add.hpp>
......
...@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -66,15 +66,15 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
return not input_shape.dyn_dims().at(ax).is_fixed();
}))
{
return vec;
}
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i); return input_shape.dyn_dims().at(i).max;
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
}); });
} }
else else
......
...@@ -97,10 +97,11 @@ struct onnx_parser ...@@ -97,10 +97,11 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
......
...@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
} }
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations; parser.max_loop_iterations = options.max_loop_iterations;
parser.limit_max_iterations = options.limit_max_iterations;
parser.use_dyn_output = options.use_dyn_output; parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error) if(options.print_program_on_error)
......
...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"}, {"Neg", "neg"},
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "nearbyint"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_isinf : op_parser<parse_isinf>
{
std::vector<op_desc> operators() const { return {{"IsInf", "isinf"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
bool detect_negative = true;
bool detect_positive = true;
if(contains(info.attributes, "detect_negative"))
{
detect_negative = static_cast<bool>(
parser.parse_value(info.attributes.at("detect_negative")).at<int>());
}
if(contains(info.attributes, "detect_positive"))
{
detect_positive = static_cast<bool>(
parser.parse_value(info.attributes.at("detect_positive")).at<int>());
}
auto x_shape = args[0]->get_shape();
if(not detect_negative and not detect_positive)
{
return info.add_instruction(
make_op("multibroadcast", {{"out_lens", x_shape.lens()}}),
info.add_literal(migraphx::literal{migraphx::shape{shape::bool_type}, {false}}));
}
auto is_inf = info.add_instruction(make_op("isinf"), args[0]);
if(detect_negative and detect_positive)
{
return is_inf;
}
auto zero_l = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
auto mb_zero =
info.add_instruction(make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), zero_l);
auto cond = info.add_broadcastable_binary_op(
detect_negative ? "less" : "greater", args[0], mb_zero);
if(cond->get_shape().type() != shape::bool_type)
{
cond =
info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), cond);
}
return info.add_instruction(make_op("logical_and"), is_inf, cond);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop> ...@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
} }
} }
// cap max_iter because loop uses static shapes with max_iter size and huge numbers
// here can cause overflow
if(max_iterations > parser.limit_max_iterations)
{
std::cerr << "WARNING: PARSE_LOOP max_iterations exceeds the maximum loop "
"iterations limit, it will be changed from "
<< max_iterations << " to " << parser.limit_max_iterations << ".\n";
max_iterations = parser.limit_max_iterations;
}
// condition input is empty // condition input is empty
if(args.at(1)->name() == "undefined") if(args.at(1)->name() == "undefined")
{ {
......
...@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr) ...@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return nearest_mode; return nearest_mode;
} }
static std::vector<double> get_scales(const onnx_parser::attribute_map& attr)
{
std::vector<double> scales;
if(contains(attr, "scales"))
{
copy(attr.at("scales").floats(), std::back_inserter(scales));
}
return scales;
}
static void parse_args(const std::vector<instruction_ref>& args,
const std::vector<size_t>& in_lens,
const std::string& op_name,
std::vector<double>& vec_scale,
std::vector<std::size_t>& out_lens)
{
for(const auto& arg : args)
{
if(arg->name() == "undefined" or arg == args.front())
{
continue;
}
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + op_name +
": specified output size does not match input size");
}
// compute the scale
vec_scale.resize(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
// scale input
if(lens[0] == in_lens.size())
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + op_name + ": dynamic input scale is not supported!");
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
}
}
}
}
struct parse_resize : op_parser<parse_resize> struct parse_resize : op_parser<parse_resize>
{ {
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; } std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
...@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize> ...@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
std::vector<std::size_t> out_lens(in_lens.size()); std::vector<std::size_t> out_lens(in_lens.size());
// scale // scale
std::vector<double> vec_scale; std::vector<double> vec_scale = get_scales(info.attributes);
for(const auto& arg : args) // If `scales` was not an attribute, it must be an input
if(vec_scale.empty())
{ {
if(arg->name() == "undefined" or arg == args.front()) // Depending on the args, it *must* populate the `vec_scale`, and might populate
{ // `out_lens`
continue; parse_args(args, in_lens, opd.op_name, vec_scale, out_lens);
} }
// skipped empty input
auto lens = arg->get_shape().lens();
if(lens.empty())
{
continue;
}
auto type = arg->get_shape().type();
// output size
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
}
// compute the scale if(in_lens.size() != vec_scale.size())
vec_scale.resize(in_lens.size()); {
std::transform(in_lens.begin(), MIGRAPHX_THROW("PARSE_" + opd.op_name + ": ranks of input and scale are different!");
in_lens.end(), }
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return 1.0 * oss / iss; });
}
else
{
// scale input // if the output was not calculated yet, we update it based on the scales
if(lens[0] == in_lens.size()) if(all_of(out_lens.cbegin(), out_lens.cend(), [](auto o) { return o == 0; }))
{ {
auto arg_scale = arg->eval(); std::transform(
check_arg_empty(arg_scale, in_lens.begin(),
"PARSE_" + opd.op_name + in_lens.end(),
": dynamic input scale is not supported!"); vec_scale.begin(),
out_lens.begin(),
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); }); [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
}
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
return static_cast<std::size_t>(idx * scale);
});
}
}
} }
shape out_s{in_s.type(), out_lens}; shape out_s{in_s.type(), out_lens};
...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize> ...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension // reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())}; std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
args[0] = info.make_contiguous(args[0]);
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]); auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
if(mode == "nearest") if(mode == "nearest")
......
...@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice> ...@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); } void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std::vector<int64_t> insert(instruction_ref arg) std::vector<int64_t> insert(instruction_ref arg)
{ {
std::vector<int64_t> result; std::vector<int64_t> result;
...@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice> ...@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes; sd.op.axes = axes;
} }
if(not sd.steps.empty()) if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{ {
if(sd.op.starts.empty() or sd.op.ends.empty()) if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported"); MIGRAPHX_THROW(
"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty()) if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported"); MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
} }
assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op // If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size())) for(auto i : range(sd.steps.size()))
{ {
......
...@@ -472,7 +472,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -472,7 +472,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
map_dyn_input_dims, map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error,
int64_t max_loop_iterations) { int64_t max_loop_iterations,
int64_t limit_max_iterations) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value; options.default_dyn_dim_value = default_dyn_dim_value;
...@@ -481,6 +482,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -481,6 +482,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations; options.max_loop_iterations = max_loop_iterations;
options.limit_max_iterations = limit_max_iterations;
return migraphx::parse_onnx(filename, options); return migraphx::parse_onnx(filename, options);
}, },
"Parse onnx file", "Parse onnx file",
...@@ -492,7 +494,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -492,7 +494,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(), std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10); py::arg("max_loop_iterations") = 10,
py::arg("limit_max_iterations") = std::numeric_limits<uint16_t>::max());
m.def( m.def(
"parse_onnx_buffer", "parse_onnx_buffer",
......
...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x); ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
} }
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale); auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div); auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp> #include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -131,10 +132,53 @@ struct find_const_4in_slice ...@@ -131,10 +132,53 @@ struct find_const_4in_slice
} }
}; };
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches( match::find_matches(m,
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{}); find_static_2in_broadcasts{},
find_static_dimensions_of{},
find_const_3in_slice{},
find_const_4in_slice{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -647,8 +647,8 @@ struct find_broadcast_transpose ...@@ -647,8 +647,8 @@ struct find_broadcast_transpose
{ {
auto transpose = r.result; auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens(); auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"]; auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front(); auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose // scalar transformation does not need extra transpose
if(not input->get_shape().scalar()) if(not input->get_shape().scalar())
{ {
......
# #################################################################################### # ####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -245,10 +245,14 @@ else() ...@@ -245,10 +245,14 @@ else()
endif() endif()
# Check miopen find mode api # Check miopen find mode api
include(CheckLibraryExists) include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION) get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
...@@ -271,6 +275,13 @@ else() ...@@ -271,6 +275,13 @@ else()
message(STATUS "MIOpen does not have find mode api") message(STATUS "MIOpen does not have find mode api")
endif() endif()
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL) if(MIGRAPHX_USE_COMPOSABLEKERNEL)
......
...@@ -168,6 +168,7 @@ struct compile_plan ...@@ -168,6 +168,7 @@ struct compile_plan
} }
const compiled_result& benchmark(problem_cache& pc) const const compiled_result& benchmark(problem_cache& pc) const
{ {
const auto trace_level = value_of(MIGRAPHX_TRACE_BENCHMARKING{});
if(results.empty()) if(results.empty())
MIGRAPHX_THROW("No configs to tune"); MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1) if(results.size() == 1)
...@@ -178,9 +179,10 @@ struct compile_plan ...@@ -178,9 +179,10 @@ struct compile_plan
} }
if(not config) if(not config)
MIGRAPHX_THROW("Multiple kernels without config"); MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" if(trace_level > 0)
<< std::endl; std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) << std::endl;
if(trace_level > 1)
std::cout << "Problem: " << config->problem << std::endl; std::cout << "Problem: " << config->problem << std::endl;
std::vector<double> times; std::vector<double> times;
times.reserve(results.size()); times.reserve(results.size());
...@@ -189,22 +191,23 @@ struct compile_plan ...@@ -189,22 +191,23 @@ struct compile_plan
config->solutions.begin(), config->solutions.begin(),
std::back_inserter(times), std::back_inserter(times),
[&](const auto& cr, const auto& solution) { [&](const auto& cr, const auto& solution) {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) if(trace_level > 1)
std::cout << "Benchmarking solution: " << solution << std::endl; std::cout << "Benchmarking solution: " << solution << std::endl;
if(not cr.has_value()) if(not cr.has_value())
{ {
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) if(trace_level > 1)
std::cout << "No binary" << std::endl; std::cout << "No binary" << std::endl;
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
} }
auto t = time_op( auto t = time_op(
*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20); *ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20);
if(enabled(MIGRAPHX_TRACE_BENCHMARKING{})) if(trace_level > 1)
std::cout << t << "ms" << std::endl; std::cout << t << "ms" << std::endl;
return t; return t;
}); });
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; if(trace_level > 0)
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i)); pc.insert(preop.name(), config->problem, config->solutions.at(i));
if(not results[i].has_value()) if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation."); MIGRAPHX_THROW("No valid tuned compilation.");
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,15 +21,20 @@ ...@@ -21,15 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <rocblas/rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
using microseconds = std::chrono::duration<double, std::micro>;
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype get_type(shape::type_t type) rocblas_datatype get_type(shape::type_t type)
{ {
switch(type) switch(type)
...@@ -81,184 +86,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch) ...@@ -81,184 +86,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
return shape::from_permutation(s.type(), s.lens(), perm); return shape::from_permutation(s.type(), s.lens(), perm);
} }
template <class R, class... Ts, class... Us> /**
R rocblas_invoke(R (*f)(Ts...), Us... xs) * Returns results of rocblas_status_success, rocblas_status_perf_degraded,
* or rocblas_status_invalid_value. Caller
* is expected to check for invalid index. Any other result causes an exception.
*
*/
template <class F, class Pack, class... Ts>
auto rocblas_invoke(F f, Pack p, Ts... xs)
{ {
if constexpr(sizeof...(Ts) == sizeof...(Us)) return p([=](auto... ws) {
return f(xs...); auto status = f(ws..., xs...);
else if(status != rocblas_status_success and status != rocblas_status_invalid_value)
return f(xs..., nullptr, nullptr); {
if(status == rocblas_status_perf_degraded)
{
std::cerr << "WARNING: degraded perf. in rocBLAS call" << std::endl;
}
else
MIGRAPHX_THROW("rocblas_invoke: rocBLAS call failed with status " +
std::to_string(status));
}
return status;
});
} }
static bool is_transposed(const shape& s) static bool is_transposed(const shape& s) { return s.transposed() and s.strides().back() != 1; }
{
if(not s.transposed())
return false;
return s.strides().back() != 1;
}
static rocblas_int get_batch_stride(const argument& a) static rocblas_int get_batch_stride(const shape& s)
{ {
return a.get_shape().strides()[a.get_shape().strides().size() - 3]; // This value is not needed for non-strided inputs
if(s.strides().size() < 3)
return 0;
else
return s.strides()[s.strides().size() - 3];
} }
template <class T> /**
void gemm_impl(context& ctx, * Wrapper for multiple rocBLAS calls. The constructor creates parameters for
const shape& output_shape, * these calls based on data shapes and other values contained in the associated
const std::vector<argument>& args, * instruction and operation.
T alpha, *
T beta, * The template parameter T is not the type of the matrix data but of the weighting
bool compute_fp32) * coefficients alpha and beta (these are float in rocBLAS internals)
*/
template <typename T>
struct gemm_impl
{ {
const bool is_3inputs = (args.size() == 4); gemm_impl(const shape& output_shape,
if(not is_3inputs) const std::vector<shape>& input_shapes,
{ T alpha_param,
beta = 0; T beta_param,
} bool compute_fp32_flag)
: alpha(alpha_param),
bool transa = is_transposed(args[0].get_shape()); beta(beta_param),
bool transb = is_transposed(args[1].get_shape()); is_3inputs(input_shapes.size() == 4),
auto n_dim = output_shape.lens().size(); compute_fp32(compute_fp32_flag)
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
auto compute_type = output_type;
if(compute_fp32)
{ {
if(arg_type == rocblas_datatype_f16_r) if(not is_3inputs)
compute_type = rocblas_datatype_f32_r; {
} beta = 0;
}
rocblas_gemm_flags flag = rocblas_gemm_flags_none; // Create lambdas that will cast alpha, beta to the output shape's type
auto a_lens = args[0].get_shape().lens(); // and retain the values being pointed to
auto b_lens = args[1].get_shape().lens(); output_shape.visit_type([&](auto as) {
output_shape.visit_type([&](auto as) { auto alpha_r = as(alpha);
auto alpha_r = as(alpha); auto beta_r = as(beta);
auto beta_r = as(beta); if(compute_fp32)
{
get_alpha = [=] { return &alpha; };
get_beta = [=] { return &beta; };
}
else
{
get_alpha = [=] { return &alpha_r; };
get_beta = [=] { return &beta_r; };
}
});
// use void pointer to select different data type if using fp32 mode transa = is_transposed(input_shapes[0]);
void* alpha_v = &alpha_r; transb = is_transposed(input_shapes[1]);
void* beta_v = &beta_r; auto n_dim = output_shape.lens().size();
auto dim_0 = n_dim - 2;
auto dim_1 = n_dim - 1;
// Leading dimensions of matrices
lda = input_shapes[0].strides()[transa ? dim_1 : dim_0];
ldb = input_shapes[1].strides()[transb ? dim_1 : dim_0];
ldc = input_shapes[2].strides()[dim_0];
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
compute_type = output_type;
if(compute_fp32) if(compute_fp32)
{ {
alpha_v = &alpha; if(arg_type == rocblas_datatype_f16_r)
beta_v = &beta; compute_type = rocblas_datatype_f32_r;
} }
auto out_lens = output_shape.lens(); auto a_lens = input_shapes[0].lens();
rocblas_int m = out_lens[dim_0]; auto b_lens = input_shapes[1].lens();
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
auto num_matrices = std::accumulate( auto out_lens = output_shape.lens();
m = out_lens[dim_0];
n = out_lens[dim_1];
k = input_shapes[0].lens()[dim_1];
a_stride = get_batch_stride(input_shapes[0]);
b_stride = get_batch_stride(input_shapes[1]);
c_stride = get_batch_stride(input_shapes[2]);
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0)) strided_batched = num_matrices > 1;
if(strided_batched and b_stride == 0 and input_shapes[0].standard())
{ {
// If the batch dimension of B is broadcasted, then we can // If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex // multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex. // instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices; m *= num_matrices;
strided_batched = false;
}
}
// the rocblas_gemm API handles inputs and output matrices as void run(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx = 0) const
// column-major format. When doing a C = A * B, we actually do {
// C^T = (B^T) * (A^T). That is the reason we input args[1] as if(strided_batched)
// A and args[0] as B in calling the rocblas_gemm. {
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
gemm_flags);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex, rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(), common_args,
transb ? rocblas_operation_transpose : rocblas_operation_none, rocblas_gemm_algo_solution_index,
transa ? rocblas_operation_transpose : rocblas_operation_none, solution_idx,
n, gemm_flags);
m, }
k, }
alpha_v,
to_pointer(args.at(1)), #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
arg_type, auto validate(context& ctx, const std::vector<shape>& input_shapes, int32_t solution_idx) const
ldb, {
to_pointer(args.at(0)), // Create dummy arguments for the shapes, and call the overloaded method
arg_type, std::vector<argument> input_args;
lda, std::transform(input_shapes.begin(),
beta_v, input_shapes.end(),
to_pointer(args[2]), std::back_inserter(input_args),
output_type, [](const shape& x) { return to_gpu(generate_argument(x)); });
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), return validate(ctx, input_args, solution_idx);
output_type, }
ldd,
compute_type, /**
rocblas_gemm_algo_standard, * Checks a particular solution for validity by running it with the flag
0, * rocblas_gemm_flags_check_solution_index (could be invalid if this model was
flag); * tuned with a different rocBLAS version)
*
* @return Returns either solution_idx if valid, or else the default value 0
* if not. The default does not mean list index 0, but tells the picker
* to choose a solution.
*/
int32_t
validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) const
{
rocblas_status_ check_valid(rocblas_status_success);
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
} }
else else
{ {
auto a_stride = get_batch_stride(args[0]); auto common_args = create_gemm_ex_args_common(ctx, input_args);
auto b_stride = get_batch_stride(args[1]); check_valid = rocblas_invoke(&rocblas_gemm_ex,
auto c_stride = get_batch_stride(args[2]); common_args,
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride; rocblas_gemm_algo_solution_index,
rocblas_invoke(&rocblas_gemm_strided_batched_ex, solution_idx,
ctx.get_stream().get_rocblas(), rocblas_gemm_flags_check_solution_index);
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
});
if(check_valid == rocblas_status_invalid_value)
{
std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
return 0;
}
return solution_idx;
}
#endif
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "...strided_batched..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto create_strided_batched_args_common(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
b_stride,
args[0].data(),
arg_type,
lda,
a_stride,
get_beta(),
args[2].data(),
output_type,
ldc,
c_stride,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
d_stride,
num_matrices,
compute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto create_gemm_ex_args_common(context& ctx, const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
get_alpha(),
args[1].data(),
arg_type,
ldb,
args[0].data(),
arg_type,
lda,
get_beta(),
args[2].data(),
output_type,
ldc,
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
compute_type);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* of the fastest one.
*/
int tune(context& ctx, const std::vector<shape>& input_shapes) const
{
// tuning meta parameters
const int hot_calls = 40;
std::vector<argument> input_args;
std::transform(input_shapes.begin(),
input_shapes.end(),
std::back_inserter(input_args),
[](const shape& x) { return to_gpu(generate_argument(x)); });
// Get the solutions list in 2 rocBLAS steps:
// 1. Find out how many solutions there are and allocate the array
// 2. Get the solutions
//
rocblas_int list_size = 0;
std::vector<rocblas_int> solution_indices;
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
&list_size);
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
nullptr,
&list_size);
solution_indices.resize(list_size);
auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
gemm_flags,
solution_indices.data(),
&list_size);
}
double best_time = std::numeric_limits<double>::max();
double first_time = -1;
// Initialize to default solution index
rocblas_int best_sol = 0;
for(auto sol : solution_indices)
{
// Warmup: the first call to an op. may not be representative since there is
// more time taken initializing caches, etc. so we won't time it.
run(ctx, input_args, sol);
double host_time = time<milliseconds>([&] {
for([[maybe_unused]] int hc : range(hot_calls))
run(ctx, input_args, sol);
ctx.finish();
});
host_time /= hot_calls;
// dev/evaluation only: track time for first solution.
if(first_time < 0)
first_time = host_time;
// track current best
if(host_time < best_time)
{
best_sol = sol;
best_time = host_time;
}
}
std::cout << "Winning GEMM solution: " << best_sol << " in " << best_time << " ms, beats "
<< first_time << "ms" << std::endl;
return best_sol;
}
#endif
private:
size_t num_matrices = 0;
rocblas_int m = 0;
rocblas_int n = 0;
rocblas_int k = 0;
bool transa = false;
bool transb = false;
T alpha = 0;
T beta = 0;
std::function<const void*()> get_alpha{};
std::function<const void*()> get_beta{};
rocblas_gemm_flags gemm_flags = rocblas_gemm_flags_none;
rocblas_int lda = 0;
rocblas_int ldb = 0;
rocblas_int ldc = 0;
rocblas_int ldd = 0;
rocblas_int a_stride = 0;
rocblas_int b_stride = 0;
rocblas_int c_stride = 0;
rocblas_int d_stride = 0;
rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_datatype arg_type = rocblas_datatype_f32_r;
rocblas_datatype output_type = rocblas_datatype_f32_r;
bool strided_batched = true;
bool is_3inputs = true;
bool compute_fp32 = true;
}; // gemm_impl
void gemm_compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx)
{
std::vector<shape> input_shapes;
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
} }
void gemm(context& ctx, void gemm_compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, int32_t alpha,
float beta, int32_t beta,
bool compute_fp32) bool compute_fp32,
int32_t solution_idx)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, compute_fp32); std::vector<shape> input_shapes;
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t gemm_finalize(context& ctx,
const shape& output_shape,
const std::vector<shape>& input_shapes,
float alpha,
float beta,
bool compute_fp32,
int32_t solution_idx)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
} }
void gemm(context& ctx, /**
const shape& output_shape, * Decides if the tune() or validate() method is appropriate and calls it.
const std::vector<argument>& args, * Return value is the chosen solution index, or 0 to let picker choose it.
int32_t alpha, */
int32_t beta, int32_t gemm_finalize(context& ctx,
bool compute_fp32) const shape& output_shape,
const std::vector<shape>& input_shapes,
int32_t alpha,
int32_t beta,
bool compute_fp32,
int32_t solution_idx)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, compute_fp32); #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(solution_idx == 0)
{
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.tune(ctx, input_shapes);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
#else
(void)ctx, (void)output_shape, (void)input_shapes;
(void)alpha, (void)beta, (void)compute_fp32;
#endif
return solution_idx;
} }
} // namespace gpu } // namespace gpu
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct context; struct context;
void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch); shape transpose_batch(const shape& s, unsigned trans_batch);
void blas_shape(const shape& s);
template <class Op> template <class Op>
struct rocblas_gemm struct rocblas_gemm
...@@ -52,6 +51,7 @@ struct rocblas_gemm ...@@ -52,6 +51,7 @@ struct rocblas_gemm
float beta = 0; float beta = 0;
bool compute_fp32 = false; bool compute_fp32 = false;
unsigned trans_batch = 0; unsigned trans_batch = 0;
int32_t solution_idx = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -60,7 +60,8 @@ struct rocblas_gemm ...@@ -60,7 +60,8 @@ struct rocblas_gemm
pack(f(self.alpha, "alpha"), pack(f(self.alpha, "alpha"),
f(self.beta, "beta"), f(self.beta, "beta"),
f(self.compute_fp32, "compute_fp32"), f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch"))); f(self.trans_batch, "trans_batch"),
f(self.solution_idx, "solution_idx")));
} }
std::string name() const std::string name() const
...@@ -76,6 +77,8 @@ struct rocblas_gemm ...@@ -76,6 +77,8 @@ struct rocblas_gemm
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.pop_back(); in_shapes.pop_back();
// When input shapes are A, B, C the GEMM equation is C  =  α AB+ β C where α, β are
// scalars
check_shapes{in_shapes, *this}.has(2, 3); check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]); blas_shape(inputs[0]);
blas_shape(inputs[1]); blas_shape(inputs[1]);
...@@ -111,11 +114,12 @@ struct rocblas_gemm ...@@ -111,11 +114,12 @@ struct rocblas_gemm
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
gemm(ctx, output_shape, args, alpha, beta, compute_fp32); gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
} }
else else
{ {
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32); gemm_compute(
ctx, output_shape, args, int32_t(alpha), int32_t(beta), compute_fp32, solution_idx);
} }
return args.back(); return args.back();
} }
...@@ -124,6 +128,33 @@ struct rocblas_gemm ...@@ -124,6 +128,33 @@ struct rocblas_gemm
{ {
return shapes.size() - 1; return shapes.size() - 1;
} }
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& input_shapes)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if(enabled(MIGRAPHX_ENABLE_GEMM_TUNING{}) or ctx.get_exhaustive_tune_flag())
{
if(this->name() == "gpu::gemm")
{
solution_idx = gemm_finalize(
ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
}
else
{
solution_idx = gemm_finalize(ctx,
output_shape,
input_shapes,
int32_t(alpha),
int32_t(beta),
compute_fp32,
solution_idx);
}
}
#else
// suppress compiler warnings
(void)ctx, (void)output_shape, (void)input_shapes;
#endif
}
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment