Unverified commit 0ffcccbc, authored by Chris Austen, committed by GitHub

Merge branch 'develop' into jit-reduce-reg

parents 4f12db9e 2b5c5f5e
...@@ -9,6 +9,8 @@ CheckOptions:
value: risky
- key: modernize-loop-convert.NamingStyle
value: lower_case
- key: misc-const-correctness.AnalyzeValues
value: 'false'
- key: performance-unnecessary-copy-initialization.AllowedTypes
value: 'shape'
- key: performance-unnecessary-value-param.AllowedTypes
......
...@@ -32,7 +32,8 @@ jobs:
# In this step, this action saves a list of existing images,
# the cache is created without them in the post run.
# It also restores the cache if it exists.
- uses: satackey/action-docker-layer-caching@v0.0.11
# name: Docker Layer Caching2
- uses: jpribyl/action-docker-layer-caching@v0.1.1
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
...@@ -81,7 +82,7 @@ jobs:
# In this step, this action saves a list of existing images,
# the cache is created without them in the post run.
# It also restores the cache if it exists.
- uses: satackey/action-docker-layer-caching@v0.0.11
- uses: jpribyl/action-docker-layer-caching@v0.1.1
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
...@@ -126,7 +127,7 @@ jobs:
# In this step, this action saves a list of existing images,
# the cache is created without them in the post run.
# It also restores the cache if it exists.
- uses: satackey/action-docker-layer-caching@v0.0.11
- uses: jpribyl/action-docker-layer-caching@v0.1.1
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
......
...@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
# Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386
# Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......
...@@ -58,6 +58,7 @@ add_library(migraphx
layout_nhwc.cpp
load_save.cpp
make_op.cpp
memory_coloring.cpp
module.cpp
msgpack.cpp
normalize_attributes.cpp
...@@ -65,8 +66,6 @@ add_library(migraphx
op_enums.cpp
operation.cpp
optimize_module.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
pad_calc.cpp
pass_manager.cpp
permutation.cpp
......
...@@ -58,12 +58,12 @@ using deduce = typename detail::deduce<T>::type;
namespace std {
template <class T>
struct common_type<migraphx::half, T> : std::common_type<float, T>
struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
{
};
template <class T>
struct common_type<T, migraphx::half> : std::common_type<float, T>
struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
{
};
......
...@@ -41,7 +41,7 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
namespace std {
template <>
struct hash<migraphx::instruction_ref>
struct hash<migraphx::instruction_ref> // NOLINT
{
using argument_type = migraphx::instruction_ref;
using result_type = std::size_t;
...@@ -52,7 +52,7 @@ struct hash<migraphx::instruction_ref>
};
template <>
struct equal_to<migraphx::instruction_ref>
struct equal_to<migraphx::instruction_ref> // NOLINT
{
using argument_type = migraphx::instruction_ref;
using result_type = bool;
......
...@@ -39,7 +39,7 @@ struct memory_coloring
{
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
std::string name() const { return "memory_coloring"; }
void apply(module& m) const;
};
......
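For context, a minimal sketch of how this pass is typically driven follows. The run_passes helper, the "allocate" allocation op, and the "scratch" parameter name are assumptions taken from the memory coloring unit tests later in this diff, not guaranteed API:

```cpp
// Hedged sketch: run memory_coloring over a module whose buffers are produced
// by an "allocate" op; the pass replaces them with loads into a single int8
// "scratch" parameter. Names and signatures are assumptions, not verified API.
#include <migraphx/memory_coloring.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <iostream>

void color_module(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::memory_coloring{"allocate"}});
    std::cout << "scratch bytes: " << m.get_parameter_shape("scratch").bytes() << "\n";
}
```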
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
...@@ -47,33 +48,103 @@ struct gathernd
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
check_shapes{inputs, *this, true}.has(2);
auto i_shape = inputs.back();
auto data_shape = inputs.front();
auto r = data_shape.ndim();
auto q = i_shape.ndim();
size_t k;
if(i_shape.dynamic())
{
// the rank of the output is a function of k, so it must be fixed.
if(not i_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = i_shape.dyn_dims().back().min;
}
else
k = i_shape.lens().back();
// Begin input validation checks.
int output_ndim = int(q) + r - k - batch_dims - 1;
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
if(batch_dims >= q or batch_dims >= r)
{
MIGRAPHX_THROW("GATHERND: rank of an input cannot be less than batch_dims=" +
std::to_string(batch_dims));
}
if(output_ndim < 0)
{
MIGRAPHX_THROW("GATHERND: Indices too large for static data input: k=" +
std::to_string(k));
}
if(migraphx::none_of(inputs, [](auto v) { return v.dynamic(); }))
{
auto indices_lens_iter = i_shape.lens().begin();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape{data_shape.type(), {1}};
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<std::size_t> output_lens(output_ndim);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_lens = data_shape.lens();
std::copy(data_lens.begin() + batch_dims + k,
data_lens.end(),
output_lens.begin() + q - 1);
}
shape output_shape{data_shape.type(), output_lens};
return output_shape;
}
else
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
// If one or both inputs are dynamic shapes, the output is dynamic.
// Make both inputs dynamic to simplify computations.
data_shape = data_shape.to_dynamic();
i_shape = i_shape.to_dynamic();
// A rank 0 output is a scalar
if(output_ndim == 0)
return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})});
// Part of the output shape comes from indices tensor, part from data tensor
std::vector<shape::dynamic_dimension> output_dims(output_ndim);
std::copy(i_shape.dyn_dims().begin(),
i_shape.dyn_dims().begin() + q - 1,
output_dims.begin());
// fill the rest of output shape from data tensor
if(k + batch_dims < r)
{
auto data_dims = data_shape.dyn_dims();
std::copy(data_dims.begin() + batch_dims + k,
data_dims.begin() + r,
output_dims.begin() + q - 1);
}
shape output_shape(data_shape.type(), output_dims);
return output_shape;
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
......
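Worked example of the shape arithmetic above: with static data of rank r = 3, indices of rank q = 2 whose last dimension is k = 2, and batch_dims = 0, the output rank is q + r - k - batch_dims - 1 = 2, built from indices dims [0, q-2] followed by data dims [batch_dims + k, r-1]. A self-contained sketch of the same computation with plain std::vector dims (an illustration only, not the migraphx::shape API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Output dims = indices dims[0 .. q-2] ++ data dims[batch_dims+k .. r-1].
std::vector<std::size_t> gathernd_out_dims(const std::vector<std::size_t>& data_dims,
                                           const std::vector<std::size_t>& index_dims,
                                           std::size_t batch_dims = 0)
{
    std::size_t r = data_dims.size();
    std::size_t q = index_dims.size();
    std::size_t k = index_dims.back();
    assert(k <= r - batch_dims);
    std::vector<std::size_t> out(index_dims.begin(), index_dims.begin() + q - 1);
    out.insert(out.end(), data_dims.begin() + batch_dims + k, data_dims.end());
    return out;
}

int main()
{
    // data {2, 3, 4}, indices {5, 2}: k = 2, so output rank = 2 + 3 - 2 - 0 - 1 = 2
    auto out = gathernd_out_dims({2, 3, 4}, {5, 2});
    assert((out == std::vector<std::size_t>{5, 4}));
}
```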
...@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* @brief
* N-dimensional Scatter operations. This struct is the parent class of ops that differ in the
* formula used to reduce (combine old and new values of) the scattered value. It was originally
* based on the ONNX ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom; represents one of the child
* operations.
*/
template <class Derived>
struct scatternd_op : op_name<Derived>
{
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
check_shapes{inputs, *this, true}.has(3);
auto data_shape = inputs.front();
auto index_shape = inputs.at(1);
auto upd_shape = inputs.back();
auto r = data_shape.ndim();
auto q = index_shape.ndim();
size_t k;
if(index_shape.dynamic())
{
// the rank of the output is a function of k, so k must be fixed.
if(not index_shape.dyn_dims().back().is_fixed())
{
MIGRAPHX_THROW(
"GATHERND: last dimension of indices tensor must be fixed (min=max)");
}
k = index_shape.dyn_dims().back().min;
}
else
k = index_shape.lens().back();
// Checks on the sizes of input tensors
if(q + r != upd_shape.ndim() + k + 1)
MIGRAPHX_THROW("ScatterND: ranks of inputs don't match. " + std::to_string(q) + " + " +
std::to_string(r) + " - " + std::to_string(k) +
" - 1 != " + std::to_string(upd_shape.ndim()));
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
if(s.broadcasted())
// Convert all static shape dimensions to dynamic so they can be compared.
// It's possible for some of the 3 inputs to be dynamic shapes and some static,
// but any dynamic dimension that's compared to a static dimension must be fixed.
auto ind_dims = index_shape.to_dynamic().dyn_dims();
auto upd_dims = upd_shape.to_dynamic().dyn_dims();
auto data_dims = data_shape.to_dynamic().dyn_dims();
// Check that corresponding portions of tensor shapes match.
if(not(std::equal(ind_dims.begin(), ind_dims.begin() + q - 1, upd_dims.begin()) and
std::equal(data_dims.begin() + k, data_dims.end(), upd_dims.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. Update dimensions must match "
"indices and data.");
if(data_shape.dynamic())
return data_shape;
else if(data_shape.broadcasted())
{
return {s.type(), s.lens()};
return {data_shape.type(), data_shape.lens()};
}
else
{
return s.with_lens(s.lens());
return data_shape.with_lens(data_shape.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin());
...@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
auto q = indices_shape.ndim();
auto r = dyn_out.computed_shape.ndim();
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
...@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
self.reduction()(output[dyn_out.computed_shape.index(out_idx)], updates[i]);
});
});
});
......
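Worked example of the consistency check above: with static data {4, 4, 4} (r = 3) and indices {2, 1} (q = 2, k = 1), updates must have shape indices[0 .. q-2] ++ data[k .. r-1] = {2, 4, 4}, and the rank relation q + r = rank(updates) + k + 1 holds (5 = 3 + 1 + 1). A small standalone sketch of that check over plain vectors (static dims only, no migraphx types):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Expected update dims = indices dims[0 .. q-2] ++ data dims[k .. r-1].
bool scatternd_updates_ok(const std::vector<std::size_t>& data_dims,
                          const std::vector<std::size_t>& index_dims,
                          const std::vector<std::size_t>& update_dims)
{
    std::size_t r = data_dims.size();
    std::size_t q = index_dims.size();
    std::size_t k = index_dims.back();
    if(k > r or q + r != update_dims.size() + k + 1)
        return false;
    std::vector<std::size_t> expected(index_dims.begin(), index_dims.begin() + q - 1);
    expected.insert(expected.end(), data_dims.begin() + k, data_dims.end());
    return expected == update_dims;
}

int main()
{
    assert(scatternd_updates_ok({4, 4, 4}, {2, 1}, {2, 4, 4}));
    assert(not scatternd_updates_ok({4, 4, 4}, {2, 1}, {2, 4})); // wrong update rank
}
```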
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);
using instruction_set = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;
// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
auto implicit_deps = m.calc_implicit_deps();
instruction_set live_set;
auto rp = reverse(m);
for(auto rins : iterator_for(rp)) // NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto ins = std::prev(rins.base());
// Add live variables
auto add_live_variables = [&](const auto& inputs) {
for(auto input : inputs)
{
auto i = instruction::get_output_alias(input);
// Skip if variable comes from parent
if(not m.has_instruction(i))
continue;
live_set.insert(i);
}
};
add_live_variables(ins->inputs());
add_live_variables(implicit_deps[ins]);
// Remove last usage
auto it = live_set.find(ins);
if(it != live_set.end())
{
live_set.erase(it);
f(ins, live_set);
}
}
}
// This will build the conflict table or interference graph. This is
// essentially a map from one allocation instruction to the set of allocations
// that are live at the same time.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
instruction_set_map conflict_table;
liveness(m, [&](auto ins, auto live_set) {
// Skip variables that aren't allocations
if(ins->name() != allocation_op)
return;
// Skip zero allocations
if(ins->get_shape().bytes() == 0)
return;
conflict_table[ins];
for(auto i : live_set)
{
if(i == ins)
continue;
// Skip variables that aren't allocations
if(i->name() != allocation_op)
continue;
// Skip zero allocations
if(i->get_shape().bytes() == 0)
continue;
conflict_table[i].insert(ins);
conflict_table[ins].insert(i);
}
});
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
return pp.second.count(pp.first) == 0;
}));
return conflict_table;
}
// Check if intervals overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
return std::max(x.first, y.first) < std::min(x.second, y.second);
}
struct allocation_segment
{
using segment = std::pair<std::size_t, std::size_t>;
std::unordered_map<instruction_ref, segment> ins2segment;
const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }
const segment* get_segment(instruction_ref ins) const
{
auto it = ins2segment.find(ins);
if(it == ins2segment.end())
return nullptr;
return &it->second;
}
// Remove segment for an instruction
void remove(instruction_ref ins)
{
auto it = ins2segment.find(ins);
if(it != ins2segment.end())
{
ins2segment.erase(it);
}
}
std::size_t max()
{
std::size_t n = 0;
for(auto&& pp : ins2segment)
{
auto seg = pp.second;
n = std::max(n, seg.second);
}
return n;
}
template <class Iterator>
static bool overlaps(Iterator first, Iterator last, const segment& s)
{
return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
}
static bool overlaps(const std::set<segment>& segments, const segment& s)
{
return overlaps(segments.begin(), segments.end(), s);
}
static auto find_gap(const std::set<segment>& segments, std::size_t n)
{
std::size_t max_end = 0;
return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
if(x.second < max_end)
return false;
max_end = x.second;
if(is_overlap(x, y))
return false;
assert(y.first >= x.second);
auto k = y.first - x.second;
return (k >= n);
});
}
static std::size_t max_type_size(const shape& s)
{
return std::accumulate(
s.sub_shapes().begin(),
s.sub_shapes().end(),
s.type_size(),
[](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
}
static std::size_t compute_alignment(instruction_ref ins)
{
auto alignment = max_type_size(ins->get_shape());
// A rough estimate for the total number of elements
auto n = ins->get_shape().bytes() / alignment;
// Check for vectorized alignment
if(n > 4)
{
auto d = n % 4;
if(d == 0)
alignment *= 4;
if(d == 2)
alignment *= 2;
}
return alignment;
}
static segment
next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
{
assert(ins->get_shape().bytes() > 0);
// Compute alignment
auto n = 1 + (ins->get_shape().bytes() - 1) / alignment;
assert(n > 0);
auto start = 0;
// Insert at the end if it can't fit at the beginning
if(segments.empty() or segments.begin()->first <= n)
{
auto it = find_gap(segments, n);
if(it == segments.end())
it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
return x.second < y.second;
});
if(it != segments.end())
start = it->second;
}
auto s = segment{start, start + n};
assert(not overlaps(segments, s));
segments.insert(s);
return s;
}
static std::unordered_map<instruction_ref, int>
create_allocation_index(const module& m, const instruction_set_map& conflict_table)
{
std::unordered_map<instruction_ref, int> result;
int i = 0;
for(auto ins : iterator_for(m))
{
if(not contains(conflict_table, ins))
continue;
result[ins] = i++;
}
return result;
}
// Build the allocation_segment map from the conflict_table
static allocation_segment
build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
{
allocation_segment as{};
std::vector<instruction_ref> conflict_queue;
// Add all allocations to the conflict_queue
std::transform(conflict_table.begin(),
conflict_table.end(),
std::back_inserter(conflict_queue),
[](auto&& pp) { return pp.first; });
auto alloc_index = create_allocation_index(m, conflict_table);
// Sort the conflict queue so we process the allocation with the most
// number of adjacent allocations first
std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(
conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
}));
// Process the conflict_queue, we refer to the current allocation as
// the parent and the adjacent allocations as children
for(auto parent : conflict_queue)
{
// Sort children by size
std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
conflict_table.at(parent).end());
std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
}));
assert(not contains(children, parent));
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
assert(as.get_segment(parent) == nullptr);
as.add_segment(parent, next_segment(segments, parent, alignment));
}
// Reduce the number of segments
for(std::size_t n = 0; n < 3; n++)
{
for(auto parent : conflict_queue)
{
auto children = conflict_table.at(parent);
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
// Get the segment for the parent
const auto* parent_segment = as.get_segment(parent);
assert(parent_segment != nullptr);
auto s = next_segment(segments, parent, alignment);
if(s != *parent_segment and s.second <= as.max())
{
as.add_segment(parent, s);
}
}
}
return as;
}
};
static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
std::size_t alignment = 1;
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
}
return alignment;
}
void memory_coloring::apply(module& m) const
{
const std::size_t alignment = find_max_alignment(m, allocation_op);
auto conflict_table = build_conflict_table(m, allocation_op);
auto as = allocation_segment::build(m, conflict_table, alignment);
// All allocations should have a segment
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
return as.get_segment(pp.first);
}));
// Adjacent allocations should not have overlapping segments
assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
auto* x = as.get_segment(pp.first);
return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
auto* y = as.get_segment(ins);
assert(x and y);
return is_overlap(*x, *y);
});
}));
// Print out segments
if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
{
for(auto&& pp : conflict_table)
{
std::cout << "------- conflict -------" << std::endl;
auto s1 = as.ins2segment.at(pp.first);
std::cout << s1.first << ", " << s1.second << ": ";
m.debug_print(pp.first);
for(auto ins : pp.second)
{
auto s2 = as.ins2segment.at(ins);
std::cout << s2.first << ", " << s2.second << ": ";
m.debug_print(ins);
}
}
}
// Total memory
std::size_t n = as.max() * alignment;
// Replace allocations
auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
for(auto&& [ins, seg] : as.ins2segment)
{
assert(ins->name() == allocation_op);
auto s = ins->get_shape();
std::size_t offset = seg.first * alignment;
assert(offset < n);
m.replace_instruction(ins, op::load{s, offset}, mem);
}
// Replace zero allocation
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
assert(ins->get_shape().bytes() == 0);
m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
}
// Remove scratch parameter if its not used
if(mem->outputs().empty())
{
m.remove_instruction(mem);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
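The placement scheme in allocation_segment above can be illustrated on plain integers: segments are half-open intervals measured in units of the alignment, two segments conflict exactly when the larger start is below the smaller end (the is_overlap test), and a new allocation goes into the first gap that is big enough, otherwise after the furthest end. A toy standalone sketch of that idea (not the MIGraphX API, and without the conflict-graph ordering the real pass uses):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <set>
#include <utility>

using segment = std::pair<std::size_t, std::size_t>; // half-open [first, second)

bool is_overlap(segment x, segment y)
{
    return std::max(x.first, y.first) < std::min(x.second, y.second);
}

// Place an allocation of n units: first fit into a gap, otherwise append at the end.
segment place(std::set<segment>& used, std::size_t n)
{
    std::size_t start = 0;
    auto it = std::adjacent_find(used.begin(), used.end(), [&](segment a, segment b) {
        return not is_overlap(a, b) and (b.first - a.second) >= n;
    });
    if(it != used.end())
        start = it->second;
    else if(not used.empty())
        start = std::max_element(used.begin(), used.end(), [](segment a, segment b) {
                    return a.second < b.second;
                })->second;
    segment s{start, start + n};
    assert(std::none_of(used.begin(), used.end(), [&](segment t) { return is_overlap(s, t); }));
    used.insert(s);
    return s;
}

int main()
{
    std::set<segment> used;
    for(std::size_t n : {2, 3, 1})
    {
        auto s = place(used, n);
        std::cout << "[" << s.first << ", " << s.second << ")\n"; // [0, 2) [2, 5) [5, 6)
    }
}
```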
...@@ -113,7 +113,8 @@ struct onnx_parser
void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size);
void parse_graph(module* mod, const onnx::GraphProto& graph);
std::vector<instruction_ref>
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
...@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
else
...@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
else
...@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version;
}
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
...@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
// add the return instruction
mod->add_return(output_ins);
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
if(not inlining)
{
// add the return instruction
mod->add_return(output_ins);
// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
return output_ins;
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
...@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!");
}
// Fold the If instruction when the condition is constant and can therefore
// be evaluated prior to inference
if(args.front()->can_eval())
{
auto cond_arg = args.front()->eval();
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
return parser.parse_graph(mod, then_graph, true);
}
// else branch
else
{
return parser.parse_graph(mod, else_graph, true);
}
}
std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name);
...@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref else_mdl = parser.prog.create_module(else_name);
// parse the then sub_graph
parser.parse_graph(then_mdl, then_graph);
(void)parser.parse_graph(then_mdl, then_graph);
// parse the else sub_graph
parser.parse_graph(else_mdl, else_graph);
(void)parser.parse_graph(else_mdl, else_graph);
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
......
...@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph
parser.parse_graph(sub_mod, sub_graph);
(void)parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
......
...@@ -39,6 +39,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);
void validate_pass(module& mod, const pass& p, tracer trace)
{
...@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
virtual void run_pass(const pass& p) override
{
assert(mod);
timer ts{};
using seconds = std::chrono::duration<double>;
trace("Module: ", mod->name(), ", Pass: ", p.name());
const double t1 = ts.record<seconds>();
assert(mod->validate() == mod->end());
p.apply(*this);
if(enabled(MIGRAPHX_TIME_PASSES{}))
{
using milliseconds = std::chrono::duration<double, std::milli>;
auto ms = time<milliseconds>([&] { p.apply(*this); });
std::cout << p.name() << ": " << ms << "ms\n";
}
else
{
p.apply(*this);
}
trace(*mod);
validate_pass(*mod, p, *t);
const double t2 = ts.record<seconds>();
trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
}
};
......
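The MIGRAPHX_TIME_PASSES change above gates per-pass timing behind an environment variable instead of always recording timestamps. A standalone sketch of the same pattern using only std::chrono (the helper names here are illustrative, not the MIGraphX time<>/enabled() API):

```cpp
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <string>

// Time a callable and return the elapsed milliseconds.
template <class F>
double time_ms(F f)
{
    auto start = std::chrono::steady_clock::now();
    f();
    std::chrono::duration<double, std::milli> d = std::chrono::steady_clock::now() - start;
    return d.count();
}

// Apply a "pass", printing its runtime only when the environment variable is set.
template <class F>
void run_pass(const std::string& name, F apply)
{
    if(std::getenv("TRACE_PASS_TIME") != nullptr)
        std::cout << name << ": " << time_ms(apply) << "ms\n";
    else
        apply();
}

int main()
{
    run_pass("memory_coloring", [] { /* the real pass body would run here */ });
}
```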
...@@ -336,7 +336,8 @@ std::vector<argument> generic_eval(const module* mod,
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
{
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
"} for parameter: " + param_name +
" should be: " + to_string(ins->get_shape()));
}
return param;
}));
......
...@@ -691,7 +691,7 @@ TEST_CASE(test38)
auto p83 = m.add_instruction(pass_op{}, p78, p77);
m.add_instruction(pass_op{}, output, p83, p63);
run_pass(m);
CHECK(m.get_parameter_shape("scratch").bytes() == 7225344); // Optimal solution is 6422528
CHECK(m.get_parameter_shape("scratch").bytes() == 6422528);
CHECK(no_allocate(m));
}
...@@ -729,7 +729,7 @@ TEST_CASE(test39)
run_pass(*smod);
}
CHECK(mm->get_parameter_shape("scratch").bytes() == 4);
CHECK(mm->get_parameter_shape("scratch").bytes() == 1);
CHECK(then_mod->get_parameter_shape("scratch").bytes() == 24);
CHECK(else_mod->get_parameter_shape("scratch").bytes() == 24);
CHECK(no_allocate(*mm));
...@@ -3374,7 +3374,7 @@ TEST_CASE(rnn_dom)
m.add_instruction(pass_op{}, moutput, mx250, mx249, mx248);
run_pass(m);
CHECK(m.get_parameter_shape("scratch").bytes() == 1600);
CHECK(m.get_parameter_shape("scratch").bytes() == 1824); // Optimal is 1600
CHECK(no_allocate(m));
CHECK(is_disjoint({mx0, mx8}));
CHECK(is_disjoint({mx0, mx8}));
...@@ -3790,4 +3790,23 @@ TEST_CASE(literal_test)
CHECK(lit == result);
}
TEST_CASE(test_tuple)
{
migraphx::module m;
auto s1 = migraphx::shape{migraphx::shape::float_type, {8}};
auto s2 = migraphx::shape{migraphx::shape::half_type, {10}};
auto s = migraphx::shape{{s1, s2}};
auto a1 = add_alloc(m, s);
auto m1 = m.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(m, {migraphx::shape::float_type, {4}});
m.add_instruction(pass_op{}, a2, m1);
run_pass(m);
CHECK(m.get_parameter_shape("scratch").bytes() == 68);
CHECK(no_allocate(m));
CHECK(is_disjoint({a1, a2}));
}
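(For the sizes checked above, assuming 4-byte floats and 2-byte halfs: the tuple needs 8 * 4 + 10 * 2 = 52 bytes, the second allocation needs 4 * 4 = 16 bytes, and since the two must stay disjoint the scratch buffer is 52 + 16 = 68 bytes.)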
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -2132,6 +2132,19 @@ def gathernd_test():
return ([node], [x, i], [y])
@onnx_test()
def gathernd_dyn_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
i = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])
node = onnx.helper.make_node('GatherND',
inputs=['data', 'indices'],
outputs=['y'])
return ([node], [x, i], [y])
@onnx_test()
def gathernd_batch_dims_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 2, 2])
...@@ -2498,6 +2511,58 @@ def if_else_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
@onnx_test()
def if_else_test_inlined():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
...@@ -2547,6 +2612,149 @@ def if_else_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test()
def if_then_else_multi_output_shapes_inlined_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3, 1])
then_out2 = onnx.helper.make_tensor_value_info('then_out2',
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
then_add_node2 = onnx.helper.make_node('Add',
inputs=['x', 'x'],
outputs=['then_out2'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
else_sub_node = onnx.helper.make_node('Sub',
inputs=['y', 'yt'],
outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node, then_add_node2],
'then_body', [], [then_out, then_out2])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
'else_body', [], [else_out, else_out2])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res1', 'res2'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y], [res1, res2], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test()
def if_then_else_multi_output_shapes_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT,
[2, 3, 1])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3, 1])
then_out2 = onnx.helper.make_tensor_value_info('then_out2',
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT,
[2, 3, 1])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3, 1).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
then_add_node2 = onnx.helper.make_node('Add',
inputs=['x', 'x'],
outputs=['then_out2'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
else_sub_node = onnx.helper.make_node('Sub',
inputs=['y', 'yt'],
outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node, then_add_node2],
'then_body', [], [then_out, then_out2])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
'else_body', [], [else_out, else_out2])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res1', 'res2'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res1, res2], [xt_tensor, yt_tensor])
@onnx_test()
def if_literal_test():
then_out = onnx.helper.make_tensor_value_info('then_out',
...@@ -2807,6 +3015,59 @@ def if_then_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
@onnx_test()
def if_then_test_inlined():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
...@@ -5707,6 +5968,24 @@ def scatternd_test():
return ([node], [data, indices, updates], [output])
@onnx_test()
def scatternd_dyn_test():
data = helper.make_tensor_value_info('data', TensorProto.FLOAT,
[None, 2, 2])
indices = helper.make_tensor_value_info('indices', TensorProto.INT64,
[None, 1, 2])
updates = helper.make_tensor_value_info('updates', TensorProto.FLOAT,
[None, 1, 2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[None, 2, 2])
node = onnx.helper.make_node('ScatterND',
inputs=['data', 'indices', 'updates'],
outputs=['output'])
return ([node], [data, indices, updates], [output])
@onnx_test()
def selu_test():
x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [2, 3])
......