Commit 9a758ec4 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/optimize' into ck-gsg

parents ffe2c0cc 913ae362
...@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -87,7 +87,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5 -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@0f38fb33f518b53b94b541feb9b079668c5518e8 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform( std::transform(
s0.dyn_dims().cbegin(), s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(), s0.dyn_dims().cend(),
...@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{ {
return a; return a;
} }
else if(a == one_dyn_dim or b == one_dyn_dim) else if(a == 1 or b == 1)
{ {
// setting opt to 0, may need to be changed // setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0}; return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
...@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const ...@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not contains({"undefined", "identity", "allocate"}, i->name())) not i->is_undefined())
continue; continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last)); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited; std::unordered_set<instruction_ref> visited;
......
...@@ -30,23 +30,31 @@ namespace migraphx { ...@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
T generic_read_file(const std::string& filename) T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{ {
std::ifstream is(filename, std::ios::binary | std::ios::ate); std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg(); if(nbytes == 0)
if(size < 1) {
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename); MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg); is.seekg(offset, std::ios::beg);
T buffer(size, 0); T buffer(nbytes, 0);
if(not is.read(&buffer[0], size)) if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename); MIGRAPHX_THROW("Error reading file: " + filename);
return buffer; return buffer;
} }
std::vector<char> read_buffer(const std::string& filename) std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{ {
return generic_read_file<std::vector<char>>(filename); return generic_read_file<std::vector<char>>(filename, offset, nbytes);
} }
std::string read_string(const std::string& filename) std::string read_string(const std::string& filename)
......
...@@ -198,7 +198,7 @@ struct check_shapes ...@@ -198,7 +198,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(not this->same([](const shape& s) { return s.max_lens().size(); })) if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename); std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename); std::string read_string(const std::string& filename);
void write_buffer(const std::string& filename, const char* buffer, std::size_t size); void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
...@@ -121,6 +121,8 @@ struct instruction ...@@ -121,6 +121,8 @@ struct instruction
bool can_eval() const; bool can_eval() const;
bool is_undefined() const;
argument eval(bool check_eval = true) const; argument eval(bool check_eval = true) const;
void finalize(context& ctx); void finalize(context& ctx);
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -56,13 +57,21 @@ struct argmax ...@@ -56,13 +57,21 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto lens = inputs[0].lens(); const auto& s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0};
return {shape::int64_type, dyn_dims};
}
else
{
auto lens = s0.lens();
lens[axis] = 1; lens[axis] = 1;
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
}
template <class T> template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
...@@ -79,19 +88,18 @@ struct argmax ...@@ -79,19 +88,18 @@ struct argmax
max_index = i; max_index = i;
} }
} }
return max_index; return max_index;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto batch_item_num = args.front().get_shape().lens()[axis]; auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i); auto data_idx = dyn_out.computed_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num); output[i] = this->calc_argmax(input, data_idx, batch_item_num);
}); });
}); });
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gemm.hpp> #include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,41 +39,69 @@ struct dot ...@@ -38,41 +39,69 @@ struct dot
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_type().has(2); check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if(not std::all_of( if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{ {
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands"); MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions ");
} }
if(a.dynamic() or b.dynamic())
// only handle the case that the batch size of a and b are the same {
auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend()))
{
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1;
if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
{
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims};
}
else
{
// only handle the case that all the dimensions except the last two are the same
if(not std::equal( if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend())) a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{ {
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
"} x {" + to_string_range(b.lens()) + "}"); to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
} }
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.ndim() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.ndim() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
{ {
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
"} x {" + to_string_range(b.lens()) + "}"); to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
} }
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens}; return {t, out_lens};
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result = argument{output_shape}; argument result = argument{dyn_out.computed_shape};
visit_all(result, args[0], args[1])( visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); }); [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result; return result;
......
...@@ -55,17 +55,47 @@ struct flatten ...@@ -55,17 +55,47 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this, true}.has(1);
auto&& lens = inputs.front().lens(); auto s = inputs[0];
auto x = if(s.dynamic())
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); {
auto y = auto min_lens = s.min_lens();
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); auto max_lens = s.max_lens();
return {inputs.at(0).type(), {x, y}}; auto opt_lens = s.opt_lens();
// If any of the opt values is 0, output opt will be 0
shape::dynamic_dimension x = {
std::accumulate(
min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(
max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(opt_lens.begin(),
opt_lens.begin() + axis,
std::size_t{1},
std::multiplies<>{})};
shape::dynamic_dimension y = {
std::accumulate(
min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate(
max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate(
opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
};
return {s.type(), {x, y}};
}
else
{
check_shapes{inputs, *this}.standard();
auto&& lens = s.lens();
auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {s.type(), {x, y}};
}
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -53,15 +53,15 @@ struct softmax ...@@ -53,15 +53,15 @@ struct softmax
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
if(inputs.at(0).packed()) auto s0 = inputs[0];
if(s0.dynamic() or s0.packed())
{ {
return inputs.at(0); return s0;
} }
else else
{ {
auto lens = inputs.at(0).lens(); return {s0.type(), s0.lens()};
return {inputs.at(0).type(), lens};
} }
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -54,14 +55,46 @@ struct squeeze ...@@ -54,14 +55,46 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic())
{
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return input_shape.dyn_dims()[axis] != 1;
}))
{
MIGRAPHX_THROW(
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
std::copy_if(input_shape.dyn_dims().cbegin(),
input_shape.dyn_dims().cend(),
std::back_inserter(dyn_dims),
[&](auto dd) { return dd != 1; });
}
else
{
for(auto i : range(input_shape.ndim()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
}
}
return {input_shape.type(), dyn_dims};
}
else
{
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides(); auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides; std::vector<std::size_t> new_strides;
...@@ -96,10 +129,11 @@ struct squeeze ...@@ -96,10 +129,11 @@ struct squeeze
return shape{type, new_lens, new_strides}; return shape{type, new_lens, new_strides};
} }
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -29,11 +29,20 @@ ...@@ -29,11 +29,20 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -56,8 +65,33 @@ struct unsqueeze ...@@ -56,8 +65,33 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic())
{
if(not steps.empty())
{
MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
}
else
{
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides(); auto old_strides = input_shape.strides();
...@@ -110,9 +144,10 @@ struct unsqueeze ...@@ -110,9 +144,10 @@ struct unsqueeze
} }
return shape{type, new_lens, new_strides}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -39,7 +39,7 @@ struct module_pass_manager; ...@@ -39,7 +39,7 @@ struct module_pass_manager;
struct optimize struct optimize
{ {
std::string name() const { return "optimize"; } std::string name() const { return "optimize"; }
void apply(module_pass_manager& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -101,6 +101,12 @@ struct shape ...@@ -101,6 +101,12 @@ struct shape
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x); friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
// compare to fixed std::size_t dimension
friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
......
...@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod) ...@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std::replace(module_args.begin(), module_args.end(), old, new_mod); std::replace(module_args.begin(), module_args.end(), old, new_mod);
} }
bool instruction::is_undefined() const
{
if(op.name() == "undefined")
{
return true;
}
else if(this->inputs().empty())
{
return false;
}
else
{
return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
return arg->is_undefined();
});
}
}
bool instruction::can_eval() const bool instruction::can_eval() const
{ {
if(op.name() == "@literal") if(op.name() == "@literal")
......
...@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
{ {
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(not t.external_data().empty()) auto type = get_type(t.data_type());
shape tensor_shape(type, dims);
auto external_data = t.external_data();
if(not external_data.empty())
{
const std::string& data_file = external_data.at(0).value();
size_t num_data_fields = external_data.size();
size_t offset = 0;
size_t nbytes = tensor_shape.bytes();
if(num_data_fields > 1) // if offset field is present
{
offset = std::stoul(t.external_data().at(1).value());
}
if(num_data_fields > 2) // if nbytes field is present
{ {
const std::string& data_file = t.external_data().at(0).value(); nbytes = std::stoul(t.external_data().at(2).value());
auto raw_buffer = read_buffer(path + "/" + data_file); }
auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
std::string s(raw_buffer.begin(), raw_buffer.end()); std::string s(raw_buffer.begin(), raw_buffer.end());
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
if(t.has_raw_data()) if(t.has_raw_data())
{ {
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
......
...@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const ...@@ -92,7 +92,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process sequence length // process sequence length
instruction_ref seq_lens = m.end(); instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined") if((args.size() >= 5) and not args[4]->is_undefined())
{ {
seq_lens = args[4]; seq_lens = args[4];
} }
...@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const ...@@ -117,7 +117,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias // process bias
instruction_ref bias_forward = m.end(); instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end(); instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 and not args[3]->is_undefined())
{ {
bias_forward = m.insert_instruction( bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
...@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const ...@@ -129,7 +129,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// or the 5th one (if the sequence len argument is ignored) // or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{}; instruction_ref ih_forward{};
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined") if(args.size() == 6 and not args[5]->is_undefined())
{ {
ih_forward = m.insert_instruction( ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
...@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const ...@@ -195,14 +195,14 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
// process bias and initial hidden state // process bias and initial hidden state
instruction_ref bias = m.end(); instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 and not args[3]->is_undefined())
{ {
bias = args[3]; bias = args[3];
} }
// process intial hidden state // process intial hidden state
instruction_ref ih; instruction_ref ih;
if(args.size() == 6 && args[5]->name() != "undefined") if(args.size() == 6 and not args[5]->is_undefined())
{ {
ih = args[5]; ih = args[5];
} }
...@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const ...@@ -398,7 +398,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// process sequence length // process sequence length
instruction_ref seq_lens = m.end(); instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined") if((args.size() >= 5) and not args[4]->is_undefined())
{ {
seq_lens = args[4]; seq_lens = args[4];
} }
...@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const ...@@ -423,7 +423,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias // bias
instruction_ref bias_forward = m.end(); instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end(); instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 and not args[3]->is_undefined())
{ {
bias_forward = m.insert_instruction( bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
...@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const ...@@ -434,7 +434,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// intial hidden state // intial hidden state
instruction_ref ih_forward{}; instruction_ref ih_forward{};
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined") if(args.size() == 6 and not args[5]->is_undefined())
{ {
ih_forward = m.insert_instruction( ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
...@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const ...@@ -501,14 +501,14 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
// bias // bias
instruction_ref bias = m.end(); instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 and not args[3]->is_undefined())
{ {
bias = args[3]; bias = args[3];
} }
// intial hidden state // intial hidden state
instruction_ref ih{}; instruction_ref ih{};
if(args.size() == 6 && args[5]->name() != "undefined") if(args.size() == 6 and not args[5]->is_undefined())
{ {
ih = args[5]; ih = args[5];
} }
...@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -784,7 +784,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process sequence length // process sequence length
instruction_ref seq_lens = m.end(); instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined") if((args.size() >= 5) and not args[4]->is_undefined())
{ {
seq_lens = args[4]; seq_lens = args[4];
} }
...@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -813,7 +813,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process bias // process bias
instruction_ref bias_forward = m.end(); instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end(); instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 and not args[3]->is_undefined())
{ {
bias_forward = m.insert_instruction( bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
...@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -824,7 +824,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process intial hidden state, it is the 6th argument // process intial hidden state, it is the 6th argument
instruction_ref ih_forward{}; instruction_ref ih_forward{};
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined") if(args.size() >= 6 and not args[5]->is_undefined())
{ {
ih_forward = m.insert_instruction( ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
...@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -840,7 +840,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process initial cell value // process initial cell value
instruction_ref ic_forward{}; instruction_ref ic_forward{};
instruction_ref ic_reverse{}; instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined") if(args.size() >= 7 and not args[6]->is_undefined())
{ {
ic_forward = m.insert_instruction( ic_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
...@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -856,7 +856,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole // process weight of the peephole
instruction_ref pph_forward = m.end(); instruction_ref pph_forward = m.end();
instruction_ref pph_reverse = m.end(); instruction_ref pph_reverse = m.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 and not args[7]->is_undefined())
{ {
pph_forward = m.insert_instruction( pph_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
...@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -940,14 +940,14 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// bias // bias
instruction_ref bias = m.end(); instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 and not args[3]->is_undefined())
{ {
bias = args[3]; bias = args[3];
} }
// initial hidden state // initial hidden state
instruction_ref ih{}; instruction_ref ih{};
if(args.size() >= 6 && args[5]->name() != "undefined") if(args.size() >= 6 and not args[5]->is_undefined())
{ {
ih = args[5]; ih = args[5];
} }
...@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -958,7 +958,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// initial cell value // initial cell value
instruction_ref ic{}; instruction_ref ic{};
if(args.size() >= 7 && args[6]->name() != "undefined") if(args.size() >= 7 and not args[6]->is_undefined())
{ {
ic = args[6]; ic = args[6];
} }
...@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const ...@@ -969,7 +969,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
// process weight of the peephole // process weight of the peephole
instruction_ref pph = m.end(); instruction_ref pph = m.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 and not args[7]->is_undefined())
{ {
pph = args[7]; pph = args[7];
} }
......
...@@ -521,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) ...@@ -521,6 +521,14 @@ std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
return os; return os;
} }
bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
{
return x.min == y and x.max == y;
}
bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
if(x.dynamic() and y.dynamic()) if(x.dynamic() and y.dynamic())
......
...@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options.push_back("-fno-gpu-rdc"); options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3")); options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat"); options.push_back("-Wno-cuda-compat");
options.push_back("--cuda-gpu-arch=" + arch); options.push_back("--offload-arch=" + arch);
prog.compile(options); prog.compile(options);
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
...@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
} }
else if(is_hip_clang_compiler()) else if(is_hip_clang_compiler())
{ {
params += " --cuda-gpu-arch=" + arch; params += " --offload-arch=" + arch;
params += " --cuda-device-only"; params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " "; params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment