Commit f8a75f8a authored by Paul's avatar Paul
Browse files

Merge

parents 74448ed6 d00fdf6e
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_max : scatternd_op<scatternd_max>
{
scatternd_max() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = std::max(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_min : scatternd_op<scatternd_min>
{
scatternd_min() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = std::min(x, y); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived> ...@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
auto k = indices_shape.lens().back(); auto k = indices_shape.lens().back();
auto q = indices_shape.ndim(); auto q = indices_shape.ndim();
auto r = dyn_out.computed_shape.ndim(); auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) { for(auto i = 0u; i < updates_shape.elements(); ++i)
{
auto updates_idx = updates_std.multi(i); auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0); std::vector<std::size_t> indices_idx(q, 0);
std::copy( std::copy(
...@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived> ...@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k); std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]); self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
}); }
}); });
}); });
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <array>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -84,10 +85,10 @@ struct unary : op_name<Derived> ...@@ -84,10 +85,10 @@ struct unary : op_name<Derived>
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
std::transform(input.begin(), par_transform(input.begin(),
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
}); });
}); });
return result; return result;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tune_axis.hpp>
#include <utility>
#include <map>
#include <limits>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// https://onnx.ai/onnx/operators/onnx__Unique.html
// The Onnx spec refers to numpy specification, used as a reference:
// https://numpy.org/doc/stable/reference/generated/numpy.unique.html
// Input : Given an array of elements : X.
// Output(s) :
// 1. Find the unique elements (Y) of input (X).
//
// There are three outputs in addition to the unique elements in Y:
// 2. the indices of the input array that give the unique values
// 3. the indices of the unique array that reconstruct the input array
// 4. the number of times each unique value comes up in the input array
// Optional Attribute: 'Sorted' = 1 for sorted; = 0 for unsorted.
// Onnx specification makes 'sorted' a default, while Numpy always sorts.
//
// Optional Attribute: 'Axis' is 'None' (default) or a valid int < rank(X).
// Negative values are allowed.
//
// Numpy has the following important note on Axis:
// ------------------------------------------------------------------
// When an axis is specified the subarrays indexed by the axis are
// sorted. This is done by making the specified axis the first
// dimension of the array (move the axis to the first dimension to
// keep the order of the other axes) and then flattening the subarrays
// in C order. The flattened subarrays are then viewed as a structured
// type with each element given a label, with the effect that we end
// up with a 1-D array of structured types that can be treated in the
// same way as any other 1-D array. The result is that the flattened
// subarrays are sorted in lexicographic order starting with the first
// element.
// ------------------------------------------------------------------
struct unique
{
template <class T>
auto make_idx_less_fn(const T& data, size_t chunk_sz) const
{
return [&data, chunk_sz](auto idx1, auto idx2) {
return std::lexicographical_compare(data.begin() + idx1,
data.begin() + idx1 + chunk_sz,
data.begin() + idx2,
data.begin() + idx2 + chunk_sz);
};
}
// CASE SORTED:
//
// To process into a sorted unique series of elements/chunks:
// Chunk size == 1 means a simple element; >1 means a flat representation.
// Steps: first go through the input elements/chunks for uniqueness.
// At the end of this processing, per the sorted sequence of unique elements:
// update/create data structures: y, y_indices, x_rev_indices, y_count
//
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// OUTPUT(s): indices..
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [1, 2, 3, 4] --- the unique output is constructed from x[y_indices[...]]
template <class T>
auto sorted_uniq_indices(const T& input_data, size_t chunk_sz) const
{
struct y_info
{
size_t y_idx;
size_t x_idx;
size_t ct = 0;
};
auto idx_less_fn = make_idx_less_fn(input_data, chunk_sz);
std::map<size_t, y_info, decltype(idx_less_fn)> uniq_val_map(idx_less_fn);
std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> rv;
auto& [y_indices, x_rev_indices, y_count] = rv;
// go through all the elements and find the unique elements..
size_t count_x = input_data.size();
for(size_t f_idx = 0, x_idx = 0; f_idx < count_x; f_idx += chunk_sz, x_idx++)
{
y_info entry = {.y_idx = uniq_val_map.size(), .x_idx = x_idx};
auto [itr, added_new] = uniq_val_map.insert({f_idx, entry});
itr->second.ct++;
x_rev_indices.push_back(itr->second.y_idx);
}
std::vector<std::size_t> y2x_indices(uniq_val_map.size());
y_indices.resize(uniq_val_map.size());
y_count.resize(uniq_val_map.size());
size_t idx = 0;
// the unique elements are now sorted:
// post-processing for all the return indices.
for(const auto& v : uniq_val_map)
{
y2x_indices[v.second.y_idx] = idx;
y_indices[idx] = v.second.x_idx;
y_count[idx] = v.second.ct;
idx++;
}
// update x_rev_indices as per the sorted order of y_indices
for(auto& i : x_rev_indices)
i = y2x_indices[i];
return rv;
}
// CASE UNSORTED:
//
// To process into an un-sorted unique series of elements/chunks:
// For chunk size = 1 is a simple element, else use a flat representation of a tensor obj
// Go through the input elements/chunks one by one with inline processing of indices..
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// OUTPUT(s): indices..
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [2, 1, 3, 4] --- the unique output is constructed from x[y_indices[...]]
// Output data structures: y_indices, x_rev_indices, y_count are processed inline.
template <class T>
auto unsorted_uniq_indices(const T& input_data, size_t chunk_sz) const
{
auto idx_less_fn = make_idx_less_fn(input_data, chunk_sz);
std::map<size_t, size_t, decltype(idx_less_fn)> uniq_val_map(idx_less_fn);
// rv is used for NVRO below..
std::tuple<std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> rv;
auto& [y_indices, x_rev_indices, y_count] = rv;
// go through all the elements and add the unique elements into the map..
// inline processing for outputs: y_indices, x_rev_indices, y_count
size_t count_x = input_data.size();
for(size_t f_idx = 0; f_idx < count_x; f_idx += chunk_sz)
{
auto [itr, added_new] = uniq_val_map.insert({f_idx, y_indices.size()});
if(added_new)
{
y_count.push_back(0);
y_indices.push_back(x_rev_indices.size());
}
y_count[itr->second]++;
x_rev_indices.push_back(itr->second);
}
return rv;
}
// Axis. Default: none. Range: [-rank, rank-1]
std::optional<int64_t> axis;
// Sorted, Default: 1= sorted. 0 = unsorted.
bool sorted = true;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.sorted, "sorted"));
}
std::string name() const { return "unique"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto& sh_x = inputs[0];
auto lens_x = sh_x.lens();
size_t dim_x = sh_x.ndim();
size_t max_uniq_ct = sh_x.elements();
std::vector<shape::dynamic_dimension> d_out;
if(axis)
{
int64_t t_axis = migraphx::tune_axis(dim_x, *axis, name());
if(t_axis != 0)
MIGRAPHX_THROW("Unique: Only supports axis = 0 or None");
d_out = sh_x.to_dynamic().dyn_dims();
// only axis = 0 is supported:
max_uniq_ct = lens_x[0];
// min = 1 unique element; max = full dimension along axis 0
d_out[0] = {1, max_uniq_ct};
}
else
{
d_out.push_back({1, max_uniq_ct});
}
shape sh_y = {sh_x.type(), d_out};
// The three outputted Indices are just 1-D:
shape sh_idx{shape::int64_type, {d_out[0]}};
return {{sh_y, sh_idx, sh_idx, sh_idx}};
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
auto sh_x = args.front().get_shape();
auto lens_x = sh_x.lens();
shape output_shape = dyn_out.computed_shape;
auto vec_ss = output_shape.sub_shapes();
auto ct_x = sh_x.elements();
shape sh_y = {vec_ss[0].type(), {ct_x}};
shape sh_idx = {vec_ss[1].type(), {ct_x}};
shape sh_x_idx = {vec_ss[1].type(), {ct_x}};
argument res_y{sh_y};
argument res_y_idx{sh_idx};
argument res_x_rev_idx{sh_idx};
argument res_y_ct_idx{sh_idx};
std::vector<size_t> out_y_idx;
std::vector<size_t> out_x_rev_idx;
std::vector<size_t> out_y_ct;
// If axis is not none, for >1D tensors, we have to consider
// then, the uniqueness of chunks of sub-tensors: a subsequence of built-ins..
// For a built-in type, chunk_sz is of course = 1
size_t chunk_sz = 1;
if(axis)
chunk_sz = ct_x / lens_x[0]; // axis = 0 is supported.
visit_all(args.front(), res_y)([&](auto x, auto y_flat) {
using o_type = typename decltype(x)::value_type;
std::vector<o_type> x_in(x.begin(), x.end());
std::tie(out_y_idx, out_x_rev_idx, out_y_ct) =
sorted ? sorted_uniq_indices(x_in, chunk_sz)
: unsorted_uniq_indices(x_in, chunk_sz);
const auto uniq_ct = out_y_idx.size();
// construct y from x[indices] in flattened form
// later we reshape y to the final shape..
auto y_dst = y_flat.begin();
for(size_t idx = 0; idx < uniq_ct; idx++)
y_dst = copy_n(x_in.begin() + out_y_idx[idx] * chunk_sz, chunk_sz, y_dst);
std::vector<size_t> lens_y;
// if axis is specified:
// the output shape keeps the n-1 dimensions of x
if(axis)
{
lens_y = lens_x;
lens_y[0] = uniq_ct;
}
else
{
lens_y = {uniq_ct};
}
sh_y = {sh_y.type(), lens_y};
sh_idx = {sh_idx.type(), {uniq_ct}};
});
visit_all(res_y_idx, res_x_rev_idx, res_y_ct_idx)(
[&](auto y_indices, auto x_rev_indices, auto y_count) {
std::copy(out_y_idx.begin(), out_y_idx.end(), y_indices.begin());
std::copy(out_x_rev_idx.begin(), out_x_rev_idx.end(), x_rev_indices.begin());
std::copy(out_y_ct.begin(), out_y_ct.end(), y_count.begin());
sh_x_idx = {sh_idx.type(), {out_x_rev_idx.size()}};
});
return {{res_y.reshape(sh_y),
res_y_idx.reshape(sh_idx),
res_x_rev_idx.reshape(sh_x_idx),
res_y_ct_idx.reshape(sh_idx)}};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -119,6 +119,8 @@ ...@@ -119,6 +119,8 @@
#include <migraphx/op/scatternd_add.hpp> #include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp> #include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp> #include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_max.hpp>
#include <migraphx/op/scatternd_min.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
...@@ -137,6 +139,7 @@ ...@@ -137,6 +139,7 @@
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp> #include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp> #include <migraphx/op/undefined.hpp>
#include <migraphx/op/unique.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp> #include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp> #include <migraphx/op/where.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
struct exception_list
{
std::vector<std::exception_ptr> exceptions;
std::mutex m;
void add_exception()
{
std::lock_guard<std::mutex> guard(m);
exceptions.push_back(std::current_exception());
}
template <class F>
auto collect(F f)
{
return [f, this](auto&&... xs) {
try
{
f(std::forward<decltype(xs)>(xs)...);
}
catch(...)
{
this->add_exception();
}
};
}
void throw_if_exception() const
{
if(not exceptions.empty())
std::rethrow_exception(exceptions.front());
}
};
} // namespace detail
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt par_transform(InputIt first1, InputIt last1, OutputIt d_first, UnaryOperation unary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(std::execution::par, first1, last1, d_first, std::move(unary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = unary_op(first1[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt par_transform(
InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, BinaryOperation binary_op)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::transform(
std::execution::par, first1, last1, first2, d_first, std::move(binary_op));
#else
simple_par_for(last1 - first1, [&](auto i) { d_first[i] = binary_op(first1[i], first2[i]); });
return d_first + (last1 - first1);
#endif
}
template <class InputIt, class UnaryFunction>
void par_for_each(InputIt first, InputIt last, UnaryFunction f)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail::exception_list ex;
std::for_each(std::execution::par, first, last, ex.collect(std::move(f)));
ex.throw_if_exception();
#else
simple_par_for(last - first, [&](auto i) { f(first[i]); });
#endif
}
template <class... Ts>
auto par_copy_if(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::copy_if(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::copy_if(std::forward<Ts>(xs)...);
#endif
}
template <class... Ts>
auto par_sort(Ts&&... xs)
{
#if MIGRAPHX_HAS_EXECUTORS
return std::sort(std::execution::par, std::forward<Ts>(xs)...);
#else
return std::sort(std::forward<Ts>(xs)...);
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
...@@ -24,93 +24,23 @@ ...@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP #define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread> #include <migraphx/par.hpp>
#include <cmath> #include <migraphx/ranges.hpp>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F> template <class F>
void par_for(std::size_t n, std::size_t min_grain, F f) void par_for(std::size_t n, F f)
{ {
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(), using iterator = basic_iota_iterator<id, std::size_t>;
n / std::max<std::size_t>(1, min_grain)); par_for_each(iterator{0, {}}, iterator{n, {}}, f);
par_for_impl(n, threadsize, f);
} }
template <class F> template <class F>
void par_for(std::size_t n, F f) void par_for(std::size_t n, std::size_t, F f)
{ {
const int min_grain = 8; par_for(n, f);
par_for(n, min_grain, f);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape ...@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \ m(int32_type, int32_t) \
m(int64_type, int64_t) \ m(int64_type, int64_t) \
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on // clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct joinable_thread : std::thread
{
template <class... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...) // NOLINT
{
}
joinable_thread& operator=(joinable_thread&& other) = default;
joinable_thread(joinable_thread&& other) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
f(i, tid);
}
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
f(i);
}
template <class F>
void simple_par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
thread_invoke(i, 0, f);
}
else
{
std::vector<joinable_thread> threads(threadsize);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
}
}
template <class F>
void simple_par_for(std::size_t n, std::size_t min_grain, F f)
{
const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
n / std::max<std::size_t>(1, min_grain));
simple_par_for_impl(n, threadsize, f);
}
template <class F>
void simple_par_for(std::size_t n, F f)
{
const int min_grain = 8;
simple_par_for(n, min_grain, f);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -24,21 +24,21 @@ ...@@ -24,21 +24,21 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP #define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR") inline int tune_axis(int n_dim, int axis, const std::string& op_name = "OPERATOR")
{ {
if(axis >= n_dim or std::abs(axis) > n_dim) if(axis < 0)
{ axis += n_dim;
if(axis < 0 or axis >= n_dim)
MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range."); MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
}
return (axis < 0) ? axis + n_dim : axis; return axis;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -28,25 +28,35 @@ ...@@ -28,25 +28,35 @@
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \ template <> \
struct trait<T> : std::true_type \ struct trait<T> : std::true_type \
{ \ { \
}; };
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic);
MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
template <class T> template <class T>
using accumulator_type = using accumulator_type =
std::conditional_t<is_floating_point<T>{}, std::conditional_t<is_floating_point<T>{},
......
...@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED) ...@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto) protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto)
add_library(onnx-proto STATIC ${PROTO_SRCS}) add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR}) target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w) if(MSVC)
target_compile_options(onnx-proto PRIVATE /w)
else()
target_compile_options(onnx-proto PRIVATE -w)
endif()
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
...@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) ...@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header(migraphx_onnx) migraphx_generate_export_header(migraphx_onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx) rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL") target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
if(NOT WIN32)
target_link_libraries(migraphx_onnx PRIVATE "-Wl,--exclude-libs,ALL")
endif()
target_link_libraries(migraphx_onnx PUBLIC migraphx) target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets( rocm_install_targets(
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
// //
// Copyright (c) ONNX Project Contributors. // SPDX-License-Identifier: Apache-2.0
// Licensed under the MIT license.
syntax = "proto2"; syntax = "proto2";
...@@ -20,23 +20,16 @@ package onnx_for_migraphx; ...@@ -20,23 +20,16 @@ package onnx_for_migraphx;
// //
// This document describes the syntax of models and their computation graphs, // This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX // as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short. // Intermediate Representation, or 'IR' for short.
// //
// The normative semantic specification of the ONNX IR is found in docs/IR.md. // The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md. // Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes // Notes
// //
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility // Protobuf compatibility
// //
// To simplify framework compatibility, ONNX is defined using the subset of protobuf // To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any // that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions. // protobuf features that are only available in one of the two versions.
// //
...@@ -60,7 +53,7 @@ enum Version { ...@@ -60,7 +53,7 @@ enum Version {
_START_VERSION = 0; _START_VERSION = 0;
// The version field is always serialized and we will use it to store the // The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version // version that the graph is generated from. This helps us set up version
// control. // control.
// For the IR, we are using simple numbers starting with 0x00000001, // For the IR, we are using simple numbers starting with 0x00000001,
// which was the version we published on Oct 10, 2017. // which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001; IR_VERSION_2017_10_10 = 0x0000000000000001;
...@@ -92,15 +85,28 @@ enum Version { ...@@ -92,15 +85,28 @@ enum Version {
// - Add sparse initializers // - Add sparse initializers
IR_VERSION_2019_9_19 = 0x0000000000000006; IR_VERSION_2019_9_19 = 0x0000000000000006;
// IR VERSION 7 published on <TBD> // IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add a list to promote inference graph's initializers to global and // - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the // mutable variables. Global variables are visible in all graphs of the
// stored models. // stored models.
// - Add message TrainingInfoProto to store initialization // - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto // method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables. // can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator. // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION = 0x0000000000000007; IR_VERSION_2020_5_8 = 0x0000000000000007;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30 = 0x0000000000000008;
// IR VERSION 9 published on TBD
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION = 0x0000000000000009;
} }
// Attributes // Attributes
...@@ -121,6 +127,7 @@ message AttributeProto { ...@@ -121,6 +127,7 @@ message AttributeProto {
TENSOR = 4; TENSOR = 4;
GRAPH = 5; GRAPH = 5;
SPARSE_TENSOR = 11; SPARSE_TENSOR = 11;
TYPE_PROTO = 13;
FLOATS = 6; FLOATS = 6;
INTS = 7; INTS = 7;
...@@ -128,11 +135,12 @@ message AttributeProto { ...@@ -128,11 +135,12 @@ message AttributeProto {
TENSORS = 9; TENSORS = 9;
GRAPHS = 10; GRAPHS = 10;
SPARSE_TENSORS = 12; SPARSE_TENSORS = 12;
TYPE_PROTOS = 14;
} }
// The name field MUST be present for this version of the IR. // The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute optional string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute // In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope. // in parent scope.
...@@ -159,6 +167,7 @@ message AttributeProto { ...@@ -159,6 +167,7 @@ message AttributeProto {
optional SparseTensorProto sparse_tensor = 22; // sparse tensor value optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated. // Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph // optional ValueProto v = 12; // value - subsumes everything but graph
optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints repeated int64 ints = 8; // list of ints
...@@ -166,6 +175,7 @@ message AttributeProto { ...@@ -166,6 +175,7 @@ message AttributeProto {
repeated TensorProto tensors = 10; // list of tensors repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph repeated GraphProto graphs = 11; // list of graph
repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
repeated TypeProto type_protos = 15;// list of type protos
} }
// Defines information on value, including the name, the type, and // Defines information on value, including the name, the type, and
...@@ -185,7 +195,7 @@ message ValueInfoProto { ...@@ -185,7 +195,7 @@ message ValueInfoProto {
// Computation graphs are made up of a DAG of nodes, which represent what is // Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks. // commonly called a "layer" or "pipeline stage" in machine learning frameworks.
// //
// For example, it can be a node of type "Conv" that takes in an image, a filter // For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output. // tensor and a bias tensor, and produces the convolved output.
message NodeProto { message NodeProto {
repeated string input = 1; // namespace Value repeated string input = 1; // namespace Value
...@@ -211,7 +221,7 @@ message NodeProto { ...@@ -211,7 +221,7 @@ message NodeProto {
// TrainingInfoProto stores information for training a model. // TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step // In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model // and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been consumed. // back to its original state as if no training has been performed.
// Training algorithm improves the model based on input data. // Training algorithm improves the model based on input data.
// //
// The semantics of the initialization-step is that the initializers // The semantics of the initialization-step is that the initializers
...@@ -224,8 +234,8 @@ message NodeProto { ...@@ -224,8 +234,8 @@ message NodeProto {
// training algorithm's step. After the execution of a // training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding" // TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains // may be immediately updated. If the targeted training algorithm contains
// consecutive update stages (such as block coordinate descent methods), // consecutive update steps (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each stage. // the user needs to create a TrainingInfoProto for each step.
message TrainingInfoProto { message TrainingInfoProto {
// This field describes a graph to compute the initial tensors // This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input // upon starting the training process. Initialization graph has no input
...@@ -239,24 +249,42 @@ message TrainingInfoProto { ...@@ -239,24 +249,42 @@ message TrainingInfoProto {
// iteration to zero. // iteration to zero.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Thus, no initializer would be changed by default.
optional GraphProto initialization = 1; optional GraphProto initialization = 1;
// This field represents a training algorithm step. Given required inputs, // This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's // it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node, // initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference // optimizer node, increment of iteration count.
// graph.
// //
// The field algorithm.node is the only place the user can use GraphCall // An execution of the training algorithm step is performed by executing the
// operator. The only callable graph is the one stored in ModelProto.graph. // graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
// //
// By default, this field is an empty graph and its evaluation does not // By default, this field is an empty graph and its evaluation does not
// produce any output. // produce any output. Evaluating the default training step never
// update any initializers.
optional GraphProto algorithm = 2; optional GraphProto algorithm = 2;
// This field specifies the bindings from the outputs of "initialization" to // This field specifies the bindings from the outputs of "initialization" to
// some initializers in "ModelProto.graph.initializer" and // some initializers in "ModelProto.graph.initializer" and
// the "algorithm.initializer" in the same TrainingInfoProto. // the "algorithm.initializer" in the same TrainingInfoProto.
// See "update_binding" below for details. // See "update_binding" below for details.
// //
...@@ -284,23 +312,16 @@ message TrainingInfoProto { ...@@ -284,23 +312,16 @@ message TrainingInfoProto {
// be multiple key-value pairs in "update_binding". // be multiple key-value pairs in "update_binding".
// //
// The initializers appears as keys in "update_binding" are considered // The initializers appears as keys in "update_binding" are considered
// mutable and globally-visible variables. This implies some behaviors // mutable variables. This implies some behaviors
// as described below. // as described below.
// //
// 1. We have only unique keys in all "update_binding"s so that two global // 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one // variables may not have the same name. This ensures that one
// global variable is assigned up to once. // variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or // 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer". // "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm". // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. If an optional input of a graph is omitted when using GraphCall, the // 4. Mutable variables are initialized to the value specified by the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by // corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
// //
...@@ -375,13 +396,31 @@ message ModelProto { ...@@ -375,13 +396,31 @@ message ModelProto {
// //
// If this field is empty, the training behavior of the model is undefined. // If this field is empty, the training behavior of the model is undefined.
repeated TrainingInfoProto training_info = 20; repeated TrainingInfoProto training_info = 20;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated FunctionProto functions = 25;
}; };
// StringStringEntryProto follows the pattern for cross-proto-version maps. // StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps // See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto { message StringStringEntryProto {
optional string key = 1; optional string key = 1;
optional string value= 2; optional string value = 2;
}; };
message TensorAnnotation { message TensorAnnotation {
...@@ -397,7 +436,7 @@ message TensorAnnotation { ...@@ -397,7 +436,7 @@ message TensorAnnotation {
// Graphs // Graphs
// //
// A graph defines the computational logic of a model and is comprised of a parameterized // A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs. // list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning // This is the equivalent of the "network" or "graph" in many deep learning
// frameworks. // frameworks.
...@@ -409,8 +448,9 @@ message GraphProto { ...@@ -409,8 +448,9 @@ message GraphProto {
optional string name = 2; // namespace Graph optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph. // A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// MAY also appear in the input list. // The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated TensorProto initializer = 5; repeated TensorProto initializer = 5;
// Initializers (see above) stored in sparse format. // Initializers (see above) stored in sparse format.
...@@ -433,13 +473,8 @@ message GraphProto { ...@@ -433,13 +473,8 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14; repeated TensorAnnotation quantization_annotation = 14;
// DO NOT USE the following fields, they were deprecated from earlier versions. reserved 3, 4, 6 to 9;
// repeated string input = 3; reserved "ir_version", "producer_version", "producer_tag", "domain";
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
} }
// Tensors // Tensors
...@@ -474,6 +509,17 @@ message TensorProto { ...@@ -474,6 +509,17 @@ message TensorProto {
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16; BFLOAT16 = 16;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here. // Future extensions go here.
} }
...@@ -507,11 +553,11 @@ message TensorProto { ...@@ -507,11 +553,11 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64. // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true]; repeated float float_data = 4 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16 values // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer. // to writing to the buffer.
// When this field is present, the data_type field MUST be // When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true]; repeated int32 int32_data = 5 [packed = true];
// For strings. // For strings.
...@@ -589,6 +635,8 @@ message TensorProto { ...@@ -589,6 +635,8 @@ message TensorProto {
message SparseTensorProto { message SparseTensorProto {
// The sequence of non-default values are encoded as a tensor of shape [NNZ]. // The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors. // The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
optional TensorProto values = 1; optional TensorProto values = 1;
// The indices of the non-default values, which may be stored in one of two formats. // The indices of the non-default values, which may be stored in one of two formats.
...@@ -619,7 +667,7 @@ message TensorShapeProto { ...@@ -619,7 +667,7 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor // Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure // dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor. // that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations. // for pre-defined dimension denotations.
optional string denotation = 3; optional string denotation = 3;
}; };
...@@ -656,6 +704,23 @@ message TypeProto { ...@@ -656,6 +704,23 @@ message TypeProto {
optional TypeProto value_type = 2; optional TypeProto value_type = 2;
}; };
// wrapper for Tensor, Sequence, or Map
message Optional {
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
optional TypeProto elem_type = 1;
};
message SparseTensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
oneof value { oneof value {
// The type of a tensor. // The type of a tensor.
...@@ -672,11 +737,18 @@ message TypeProto { ...@@ -672,11 +737,18 @@ message TypeProto {
// The type of a map. // The type of a map.
Map map_type = 5; Map map_type = 5;
// The type of an optional.
Optional optional_type = 9;
// Type of the sparse tensor
SparseTensor sparse_tensor_type = 8;
} }
// An optional denotation can be used to denote the whole // An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is // type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations. // for pre-defined type denotations.
optional string denotation = 6; optional string denotation = 6;
} }
...@@ -696,7 +768,67 @@ message OperatorSetIdProto { ...@@ -696,7 +768,67 @@ message OperatorSetIdProto {
optional int64 version = 2; optional int64 version = 2;
} }
// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
optional string name = 1;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved 2;
reserved "since_version";
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved 3;
reserved "status";
// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated string attribute = 6;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated AttributeProto attribute_proto = 11;
// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
optional string doc_string = 8;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
// For using protobuf-lite repeated OperatorSetIdProto opset_import = 9;
option optimize_for = LITE_RUNTIME;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;
}
// For using protobuf-lite
option optimize_for = LITE_RUNTIME;
\ No newline at end of file
...@@ -34,7 +34,9 @@ ...@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case onnx::AttributeProto::TENSORS: case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::SPARSE_TENSOR: case onnx::AttributeProto::SPARSE_TENSOR:
case onnx::AttributeProto::SPARSE_TENSORS: case onnx::AttributeProto::SPARSE_TENSORS:
case onnx::AttributeProto::TYPE_PROTOS:
case onnx::AttributeProto::TYPE_PROTO:
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type())); MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data()); return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT8E4M3FNUZ: {
std::vector<int32_t> data_int32(t.int32_data().begin(), t.int32_data().end());
std::vector<migraphx::fp8::fp8e4m3fnuz> data_fp8;
std::transform(data_int32.begin(),
data_int32.end(),
std::back_inserter(data_fp8),
[](float raw_val) { return migraphx::fp8::fp8e4m3fnuz{raw_val}; });
return create_literal(shape::fp8e4m3fnuz_type, dims, data_fp8);
}
case onnx::TensorProto::FLOAT8E5M2FNUZ:
case onnx::TensorProto::FLOAT8E5M2:
case onnx::TensorProto::FLOAT8E4M3FN:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype) ...@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type;
case 14:
case 15:
case 16:
case 17:
case 19:
case 20:
default: { default: {
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported"); MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
} }
......
...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv ...@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
} }
} }
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
struct parse_lstm : op_parser<parse_lstm> struct parse_lstm : op_parser<parse_lstm>
{ {
std::vector<op_desc> operators() const { return {{"LSTM"}}; } std::vector<op_desc> operators() const { return {{"LSTM"}}; }
...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>(); input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
} }
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
if(args.size() < 8) if(args.size() < 8)
{ {
...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins); args.insert(args.end(), 8 - args.size(), ins);
} }
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm", auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size}, {{"hidden_size", hidden_size},
...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm> ...@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output = auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
}; };
......
...@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long its // use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size . // number of elements is batch_size * sample_size .
size_t batch_size = s0.lens().front(); size_t batch_size = s0.lens().front();
auto rand_dummy = info.add_literal( auto rand_dummy = info.add_literal(migraphx::literal{
migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}}); migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
randoms = randoms =
info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy); info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment