Commit be5f3539 authored by Shucai Xiao

merge develop branch changes

parents 7e3bdc34 ebfe9735
@@ -35,14 +35,28 @@ struct multibroadcast
         auto input = inputs.at(0);
         if(input.lens().empty())
-            MIGRAPHX_THROW("inputs dimensions should be > 0");
+        {
+            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
+        }
         if(input.lens().size() > output_lens.size())
-            MIGRAPHX_THROW("inputs dimensions should <= output size");
+        {
+            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
+        }
-        std::vector<size_t> bcast_strides(output_lens.size(), 0);
         auto offset = output_lens.size() - input.lens().size();
         for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
+        {
+            if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
+            {
+                MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
+                               "} cannot be broadcasted to {" + to_string_range(output_lens) +
+                               "}!");
+            }
+        }
+        std::vector<size_t> bcast_strides(output_lens.size(), 0);
+        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
         {
             if(output_lens[i + offset] == input.lens()[i])
             {
......
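For reference, the new check enforces the usual NumPy-style broadcast rule: after aligning trailing dimensions, each input length must either match the output length or be 1. A hypothetical walkthrough (values not from this commit):

// input lens {3, 1, 5} against output lens {2, 3, 4, 5} (offset = 1):
//   i = 2: input 5 vs output 5 -> ok (original stride kept)
//   i = 1: input 1 vs output 4 -> ok (stride forced to 0, value repeated)
//   i = 0: input 3 vs output 3 -> ok
// input lens {3, 2} against output lens {3, 4} would now throw, since 2 != 4 and 2 != 1.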
@@ -48,51 +48,21 @@ struct pooling
         assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
         assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));

-        if(padding_mode == default_)
-        {
-            return {t,
-                    {
-                        input.lens()[0],
-                        input.lens()[1],
-                        std::size_t(std::max<std::ptrdiff_t>(
-                            1,
-                            floor_divide<std::ptrdiff_t>(
-                                input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
-                                1)),
-                        std::size_t(std::max<std::ptrdiff_t>(
-                            1,
-                            floor_divide<std::ptrdiff_t>(
-                                input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
-                                1)),
-                    }};
-        }
-        else if(padding_mode == same)
-        {
-            return {t,
-                    {input.lens()[0],
-                     input.lens()[1],
-                     ceil_divide<std::size_t>(input.lens()[2], stride[0]),
-                     ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
-        }
-        else if(padding_mode == valid)
-        {
-            return {
-                t,
-                {
-                    input.lens()[0],
-                    input.lens()[1],
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
-                    std::size_t(std::max<std::ptrdiff_t>(
-                        1,
-                        floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
-                }};
-        }
-        else
-        {
-            MIGRAPHX_THROW("Invalid padding mode");
-        }
+        return {t,
+                {
+                    input.lens()[0],
+                    input.lens()[1],
+                    std::size_t(std::max<std::ptrdiff_t>(
+                        1,
+                        floor_divide<std::ptrdiff_t>(input.lens()[2] + 2 * padding[0] - lengths[0],
+                                                     stride[0]) +
+                            1)),
+                    std::size_t(std::max<std::ptrdiff_t>(
+                        1,
+                        floor_divide<std::ptrdiff_t>(input.lens()[3] + 2 * padding[1] - lengths[1],
+                                                     stride[1]) +
+                            1)),
+                }};
     }
 };
......
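The retained branch is the explicit-padding formula; each spatial output dimension is

\[ \text{out} = \max\!\left(1, \left\lfloor \frac{D + 2p - k}{s} \right\rfloor + 1\right) \]

A quick sanity check with hypothetical values D = 7, p = 1, k = 3, s = 2: floor((7 + 2 - 3) / 2) + 1 = 4.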
#ifndef MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_sum"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
lens[axis] = 1;
return {s.type(), lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx;
for(auto axis : axes)
out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
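As a rough sketch of what this operator computes (hypothetical shapes, not part of the commit): compute_shape sets every reduced axis length to 1, and compute accumulates each input element into the output slot obtained by zeroing the reduced axes of its index.

// Hypothetical illustration of reduce_sum semantics:
// input (shape {2, 3}):  {{1, 2, 3},
//                         {4, 5, 6}}
// op::reduce_sum{{1}} -> output (shape {2, 1}): {{6}, {15}}
// Each in_idx maps to out_idx by setting out_idx[axis] = 0 for every
// reduced axis, then the element is accumulated with +=.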
 #ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
 #define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
-#include <array>
 #include <migraphx/operation.hpp>
 #include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
......
@@ -5,6 +5,8 @@
 #include <migraphx/op/abs.hpp>
 #include <migraphx/op/acos.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/argmax.hpp>
+#include <migraphx/op/argmin.hpp>
 #include <migraphx/op/asin.hpp>
 #include <migraphx/op/as_shape.hpp>
 #include <migraphx/op/atan.hpp>
@@ -23,6 +25,7 @@
 #include <migraphx/op/div.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/elu.hpp>
+#include <migraphx/op/erf.hpp>
 #include <migraphx/op/exp.hpp>
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/gather.hpp>
@@ -45,6 +48,7 @@
 #include <migraphx/op/pooling.hpp>
 #include <migraphx/op/quant_convolution.hpp>
 #include <migraphx/op/quant_dot.hpp>
+#include <migraphx/op/reduce_sum.hpp>
 #include <migraphx/op/relu.hpp>
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/rnn.hpp>
......
@@ -2,13 +2,24 @@
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP

 #include <utility>
+#include <cstdint>
+#include <vector>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilation)
+inline void calculate_padding(int64_t idx,
+                              std::vector<int64_t>& pads,
+                              int64_t input_dim,
+                              int64_t stride,
+                              int64_t dilation,
+                              int64_t weight_dim)
 {
-    return (dilation * (weight_dim - 1)) / 2;
+    int64_t output_dim = input_dim / stride;
+    int64_t pad        = std::max(static_cast<int64_t>(0),
+                           (output_dim - 1) * stride + dilation * weight_dim - input_dim);
+    pads[idx]     = pad / 2;
+    pads[idx + 2] = pad - pad / 2;
 }

 } // namespace MIGRAPHX_INLINE_NS
......
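A quick check of the new asymmetric-split logic with hypothetical values (input_dim 224, stride 2, dilation 1, weight_dim 3):

// Hypothetical walkthrough of calculate_padding:
// output_dim = 224 / 2 = 112
// pad        = max(0, (112 - 1) * 2 + 1 * 3 - 224) = max(0, 1) = 1
// pads[idx]     = 1 / 2     = 0   (padding before)
// pads[idx + 2] = 1 - 1 / 2 = 1   (padding after; any odd remainder goes at the end)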
@@ -33,6 +33,10 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
     return std::find(c.begin(), c.end(), x);
 }

+struct empty
+{
+};
+
 } // namespace detail

 template <class C, class T>
@@ -71,6 +75,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::all_of(c.begin(), c.end(), p);
 }

+template <class Predicate>
+bool all_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class C, class Predicate>
 bool any_of(const C& c, const Predicate& p)
 {
@@ -83,6 +93,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::any_of(c.begin(), c.end(), p);
 }

+template <class Predicate>
+bool any_of(detail::empty, const Predicate&)
+{
+    return false;
+}
+
 template <class C, class Predicate>
 bool none_of(const C& c, const Predicate& p)
 {
@@ -95,6 +111,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::none_of(c.begin(), c.end(), p);
 }

+template <class Predicate>
+bool none_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class Range, class Iterator>
 void copy(Range&& r, Iterator it)
 {
......
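The detail::empty overloads follow the usual vacuous-truth convention for empty ranges; a minimal sketch of the expected results (hypothetical call sites, not from this commit):

// all_of(detail::empty{}, pred)  -> true  (no element can fail pred)
// any_of(detail::empty{}, pred)  -> false (no element can satisfy pred)
// none_of(detail::empty{}, pred) -> true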
@@ -212,6 +212,25 @@ auto visit_all(T&& x, Ts&&... xs)
     };
 }

+template <class T>
+auto visit_all(const std::vector<T>& x)
+{
+    auto&& s = x.front().get_shape();
+    if(!std::all_of(
+           x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
+        MIGRAPHX_THROW("Types must be the same");
+    return [&](auto v) {
+        s.visit_type([&](auto as) {
+            using type = typename decltype(as)::type;
+            std::vector<tensor_view<type>> result;
+            std::transform(x.begin(), x.end(), std::back_inserter(result), [&](const auto& y) {
+                return make_view(y.get_shape(), as.from(y.data()));
+            });
+            v(result);
+        });
+    };
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
......
@@ -15,35 +15,18 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
 template <bool B>
 using bool_c = std::integral_constant<bool, B>;

-template <int N>
-struct requires_enum
-{
-    enum e
-    {
-        a = 0
-    };
-};
-
-#define MIGRAPHX_REQUIRES_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
+
+#define MIGRAPHX_REQUIRES_VAR() MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__)

 #ifdef CPPCHECK
 #define MIGRAPHX_REQUIRES(...) class = void
 #else
-#if 0
-// TODO: This currently crashed on clang
-#define MIGRAPHX_REQUIRES(...)                                                           \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT(                \
-        PrivateRequires,                                                                 \
-        __LINE__) = migraphx::requires_enum<__LINE__>::a,                               \
-    class = typename std::enable_if<and_<__VA_ARGS__,                                   \
-                                         MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
-                                             migraphx::requires_enum<__LINE__>::a>{}>::type
-#else
-#define MIGRAPHX_REQUIRES(...)                                             \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT(  \
-        PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
-    class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
-#endif
+#define MIGRAPHX_REQUIRES(...)                                                                 \
+    bool MIGRAPHX_REQUIRES_VAR() = true,                                                       \
+         typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
+                                 int>::type = 0
 #endif

 } // namespace MIGRAPHX_INLINE_NS
......
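For context, the reworked macro expands into extra defaulted template parameters that SFINAE away non-matching overloads. A minimal usage sketch (hypothetical function, not part of this commit):

// `twice` only participates in overload resolution when T is integral;
// the macro supplies the enable_if parameter behind the scenes.
template <class T, MIGRAPHX_REQUIRES(std::is_integral<T>{})>
T twice(T x)
{
    return x * 2;
}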
@@ -99,6 +99,8 @@ struct shape
     /// Map element index to space index
     std::size_t index(std::size_t i) const;

+    std::vector<std::size_t> multi(std::size_t i) const;
+
     /// Returns true if the shape is packed with no padding
     bool packed() const;
     /// Returns true if the shape has been transposed. That is the strides are not in descending
......
@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)

 inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }

+inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }
+
 inline bool starts_with(const std::string& value, const std::string& prefix)
 {
     if(prefix.size() > value.size())
......
@@ -19,7 +19,7 @@ rocm_install_targets(

 add_executable(read_onnx read_onnx.cpp)
 rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraphx_onnx)
+target_link_libraries(read_onnx migraphx_cpu migraphx_onnx)

 if(MIGRAPHX_ENABLE_GPU)
......
@@ -40,6 +40,7 @@ struct onnx_parser
         add_generic_op("Sigmoid", op::sigmoid{});
         add_generic_op("Abs", op::abs{});
         add_generic_op("Exp", op::exp{});
+        add_generic_op("Erf", op::erf{});
         add_generic_op("Log", op::log{});
         // disable dropout for inference
         add_generic_op("Dropout", op::identity{});
@@ -63,6 +64,8 @@ struct onnx_parser
         add_variadic_op("Max", op::max{});
         add_variadic_op("Min", op::min{});

+        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
+        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
         add_mem_op("Clip", &onnx_parser::parse_clip);
         add_mem_op("LRN", &onnx_parser::parse_lrn);
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
@@ -93,6 +96,7 @@ struct onnx_parser
         add_mem_op("GRU", &onnx_parser::parse_gru);
         add_mem_op("LSTM", &onnx_parser::parse_lstm);
         add_mem_op("Pad", &onnx_parser::parse_pad);
+        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);

         // init the activation function map
         init_actv_func();
@@ -100,6 +104,7 @@ struct onnx_parser
     void init_actv_func()
     {
+        // Support name format of all lower case or the first letter capital
         map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
         map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
         map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
@@ -181,7 +186,15 @@ struct onnx_parser
                        s0.end(),
                        s1.begin() + offset,
                        out_lens.begin() + offset,
-                       [](auto a, auto b) { return std::max(a, b); });
+                       [&](auto a, auto b) {
+                           if(a != b and a != 1 and b != 1)
+                           {
+                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
+                                              to_string_range(s0) + "} and {" +
+                                              to_string_range(s1) + "} mismatch!");
+                           }
+                           return std::max(a, b);
+                       });
         return out_lens;
     }
@@ -265,6 +278,60 @@ struct onnx_parser
         return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
     }
instruction_ref parse_argmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(op::argmax{axis}, std::move(args));
}
}
instruction_ref parse_argmin(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(op::argmin{axis}, std::move(args));
}
}
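Both parsers follow the ONNX attribute defaults (axis = 0, keepdims = 1); when keepdims is 0 the reduced axis is squeezed away. A hypothetical shape walkthrough (values not from this commit):

// ArgMax with axis = 1 on an input of shape {2, 3, 4}:
//   keepdims = 1: op::argmax{1}              -> {2, 1, 4}
//   keepdims = 0: op::argmax{1} + squeeze{1} -> {2, 4}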
     instruction_ref
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
@@ -352,7 +419,8 @@ struct onnx_parser
         {
             // insert zeros for pad op (args[0] has 4 dims)
             padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-            l0      = prog.add_instruction(op::pad{padding}, l0);
+            l0      = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
+                                      l0);
         }
         else
         {
@@ -870,7 +938,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }

         auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
@@ -961,7 +1031,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }

         // need 4 activation functions
@@ -1088,7 +1160,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }

         // need 6 activation functions for bidirectional directions
@@ -1214,6 +1288,40 @@ struct onnx_parser
         return {hidden_states, last_output, last_cell_output};
     }
instruction_ref parse_reduce_sum(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
}
}
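A hypothetical ONNX node and its translation under this parser (illustrative only):

// ReduceSum over a rank-3 input of shape {2, 3, 4}:
//   axes = {0, 2}, keepdims = 1 -> op::reduce_sum{{0, 2}}                 : {2,3,4} -> {1,3,1}
//   axes = {0, 2}, keepdims = 0 -> reduce_sum then op::squeeze{{0, 2}}    : {2,3,4} -> {3}
//   no axes attribute            -> reduce over every dimension (the iota default above)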
     void parse_from(std::istream& is)
     {
         onnx::ModelProto model;
......
@@ -2,7 +2,6 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
 #include <migraphx/target.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/ranges.hpp>
......
@@ -8,6 +8,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/tf.hpp>
 #include <migraphx/onnx.hpp>
+#include <migraphx/type_name.hpp>

 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>
@@ -101,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
             t = as.type_enum();
             n = sizeof(as());
         }
     });

+    if(n == 0)
+    {
+        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type " + info.format);
+    }
+
     auto strides = info.strides;
     std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
         return n > 0 ? i / n : 0;
......
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/add.hpp>
+#include <migraphx/op/broadcast.hpp>
+#include <migraphx/op/concat.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/gru.hpp>
+#include <migraphx/op/lstm.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/rnn.hpp>
+#include <migraphx/op/rnn_last_output.hpp>
+#include <migraphx/op/slice.hpp>
+#include <migraphx/op/squeeze.hpp>
+#include <migraphx/op/sub.hpp>
+#include <migraphx/op/transpose.hpp>
+#include <migraphx/op/unsqueeze.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/op/common.hpp>
@@ -204,17 +217,19 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();

     // bias
+    instruction_ref bb{};
     if(bias != prog.end())
     {
-        long hs    = r->get_shape().lens()[2];
+        long hs    = static_cast<long>(r->get_shape().lens()[2]);
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
-        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
+        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
     }

     instruction_ref hidden_out = prog.end();
@@ -228,20 +243,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

         auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
         auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
-        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        instruction_ref ht;
         if(bias != prog.end())
         {
-            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
+            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
         }
-        else
-        {
-            ht = xt_ht;
-        }
+        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);

         // apply activation function
-        ht  = prog.insert_instruction(ins, actv_func, ht);
+        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
         sih = ht;

         // add the dimensions of sequence length (axis 0 for sequence length,
         // axis 1 for num_directions
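For reference, the cell being rebuilt is the standard ONNX RNN update (f is the selected activation); the rewrite only changes where the broadcast bias bb is added, not the math:

\[ H_t = f\left(X_t W^T + H_{t-1} R^T + W_b + R_b\right) \]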
@@ -485,62 +495,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     long hs = static_cast<long>(r_shape.lens()[2]);

     migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
-    std::vector<int> data(s.elements(), 1);
+    std::vector<float> data(s.elements(), 1.0f);
     auto l1 = prog.add_literal(migraphx::literal{s, data});

-    // weight matrix
+    // w matrix: squeeze to 2-dim and do a transpose
     std::vector<int64_t> perm{1, 0};
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
-    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
-    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
+    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);

+    // r is sliced into two parts, zr and h
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
-    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
-    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
+    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
+    auto rh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
+    auto trh  = prog.insert_instruction(ins, op::transpose{perm}, rh);

     // initial states
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    size_t bs = ih->get_shape().lens()[1];

     // bias
-    instruction_ref brcst_bz{};
-    instruction_ref brcst_br{};
-    instruction_ref brcst_wbh{};
-    instruction_ref brcst_rbh{};
-    instruction_ref brcst_bh{};
+    instruction_ref bwb{};
+    instruction_ref brb_zr{};
+    instruction_ref brb_h{};
     if(bias != prog.end())
     {
-        auto broadcast_lens = sih->get_shape().lens();
-        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
-
-        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
-
-        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
-        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
-        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-        brcst_br = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
-        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
-        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
+        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
+        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
+        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
+
+        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
+        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
+        brb_zr     = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
+        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
     }

     for(long i = 0; i < seq_len; i++)
@@ -549,56 +538,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

-        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
-        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
-        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
-        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
+        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
+        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
         if(bias != prog.end())
         {
-            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
+            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
+            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
         }
-        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

-        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
-        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
-        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
-        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
-        if(bias != prog.end())
-        {
-            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
-        }
-        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
+        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
+        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
+        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

-        instruction_ref xht_h;
+        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
+        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
+
+        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
+        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
+
+        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
+        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
+
+        instruction_ref hr_h{};
         if(linear_before_reset == 0)
         {
             // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
             auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
-            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
-            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
+            }
+            else
+            {
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
             }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
-            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
+            instruction_ref ht1_rh{};
             if(bias != prog.end())
             {
-                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
             }
-            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
-            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
-            if(bias != prog.end())
+            else
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
             }
+            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
         }
-        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
+
+        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
+        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);

         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
         auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
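Gathering the ONNX GRU equations referenced by the inline comments in one place (f is actv_func1, g is actv_func2, and the rewrite now computes the z/r/h gates from one fused dot product that is sliced afterwards):

\[
\begin{aligned}
z_t &= f(X_t W_z^T + H_{t-1} R_z^T + W_{bz} + R_{bz}) \\
r_t &= f(X_t W_r^T + H_{t-1} R_r^T + W_{br} + R_{br}) \\
h_t &= g(X_t W_h^T + (r_t \odot H_{t-1}) R_h^T + W_{bh} + R_{bh}) \\
H_t &= (1 - z_t) \odot h_t + z_t \odot H_{t-1}
\end{aligned}
\]

(the linear_before_reset variant applies \(r_t\) after the \(H_{t-1} R_h^T\) product instead).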
@@ -913,35 +904,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     migraphx::shape r_shape = r->get_shape();
     long seq_len            = static_cast<long>(seq_shape.lens()[0]);
     long hs                 = static_cast<long>(r_shape.lens()[2]);
+    auto bs                 = ih->get_shape().lens()[1];

     std::vector<int64_t> perm{1, 0};
-    // w matrix
+    // w matrix, squeeze and transpose
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wi      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
-    auto wo      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
-    auto wf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
-    auto wc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
-    auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
+    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);

-    // r matrix
+    // r matrix, squeeze and transpose
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto ri      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
-    auto ro      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
-    auto rf      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
-    auto rc      = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
-    auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
+    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);

     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
@@ -951,40 +923,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto ic_lens = sic->get_shape().lens();

     // bias
-    instruction_ref bi_brcst{};
-    instruction_ref bo_brcst{};
-    instruction_ref bf_brcst{};
-    instruction_ref bc_brcst{};
+    instruction_ref wrb{};
     if(bias != prog.end())
     {
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto bxi   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto bhi   = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto bi    = prog.insert_instruction(ins, op::add{}, bxi, bhi);
-        bi_brcst   = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
-
-        auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        auto bo  = prog.insert_instruction(ins, op::add{}, bxo, bho);
-        bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
-
-        auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
-        auto bf  = prog.insert_instruction(ins, op::add{}, bxf, bhf);
-        bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
-
-        auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
-        auto bc  = prog.insert_instruction(ins, op::add{}, bxc, bhc);
-        bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
+        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
+        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
+        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
+        wrb         = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
     }

     // peep hole
     instruction_ref pphi_brcst{};
     instruction_ref ppho_brcst{};
     instruction_ref pphf_brcst{};
     if(pph != prog.end())
     {
         auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
@@ -1004,44 +959,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

-        // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-        auto xt_wi          = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
-        auto ht_ri          = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
-        auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        if(pph != prog.end())
-        {
-            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
-        }
+        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
+        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
+        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
         if(bias != prog.end())
         {
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
+            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
         }
-        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);

-        // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-        auto xt_wf          = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
-        auto ht_rf          = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
-        auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
+        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
+        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
+        auto ft_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
+        auto ct_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
         if(pph != prog.end())
         {
+            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
+            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
             auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
             ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
         }
-        if(bias != prog.end())
-        {
-            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
-        }
+        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
         auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
-
-        // equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
-        auto xt_wc          = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
-        auto ht_rc          = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
-        auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
-        if(bias != prog.end())
-        {
-            ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
-        }
         auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);

         // equation Ct = ft (.) Ct-1 + it (.) ct
@@ -1050,19 +992,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
         last_cell_output = cellt;

-        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-        auto xt_wo          = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
-        auto ht_ro          = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
-        auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
         if(pph != prog.end())
         {
             auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
             ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
         }
-        if(bias != prog.end())
-        {
-            ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
-        }
         auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

         // Ht = ot (.) h(Ct)
......
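For the same reason, the ONNX LSTM equations from the deleted inline comments, gathered once (f is actv_func1, g is actv_func2, h the final activation; the rewrite computes all four gates from one fused dot product sliced along axis 1):

\[
\begin{aligned}
i_t &= f(X_t W_i^T + H_{t-1} R_i^T + P_i \odot C_{t-1} + W_{bi} + R_{bi}) \\
f_t &= f(X_t W_f^T + H_{t-1} R_f^T + P_f \odot C_{t-1} + W_{bf} + R_{bf}) \\
c_t &= g(X_t W_c^T + H_{t-1} R_c^T + W_{bc} + R_{bc}) \\
C_t &= f_t \odot C_{t-1} + i_t \odot c_t \\
o_t &= f(X_t W_o^T + H_{t-1} R_o^T + P_o \odot C_t + W_{bo} + R_{bo}) \\
H_t &= o_t \odot h(C_t)
\end{aligned}
\]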
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
         return result;
     }
 }

+std::vector<std::size_t> shape::multi(std::size_t i) const
+{
+    assert(this->standard());
+    std::vector<std::size_t> indices(lens().size());
+    std::transform(strides().begin(),
+                   strides().end(),
+                   lens().begin(),
+                   indices.begin(),
+                   [&](std::size_t stride, std::size_t len) {
+                       assert(len > 0 and stride > 0);
+                       return (i / stride) % len;
+                   });
+
+    return indices;
+}
+
 bool shape::packed() const { return this->elements() == this->element_space(); }

 bool shape::transposed() const
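A quick sanity check of multi() on a hypothetical standard shape (values not from this commit):

// shape s{shape::float_type, {2, 3, 4}};   // standard strides {12, 4, 1}
// s.multi(17) -> {17 / 12 % 2, 17 / 4 % 3, 17 / 1 % 4} = {1, 1, 1}
// and the round trip holds: 1 * 12 + 1 * 4 + 1 * 1 == 17.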
......
@@ -2,14 +2,17 @@
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/as_shape.hpp>
+#include <migraphx/op/transpose.hpp>
+#include <migraphx/op/concat.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/matcher.hpp>
 #include <unordered_set>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-bool is_reshaper(instruction_ref ins)
+const auto& reshaper_names()
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
@@ -19,17 +22,10 @@ bool is_reshaper(instruction_ref ins)
         "unsqueeze"
     };
     // clang-format on
-    return contains(names, ins->name());
+    return names;
 }

-bool is_transpose_output(instruction_ref ins)
-{
-    if(ins->outputs().size() != 1)
-        return false;
-    if(ins->outputs().front()->name() == "contiguous")
-        return is_transpose_output(ins->outputs().front());
-    return ins->outputs().front()->name() == "transpose";
-}
+bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }

 instruction_ref find_transpose_input(instruction_ref ins)
 {
@@ -42,62 +38,189 @@ instruction_ref find_transpose_input(instruction_ref ins)
     return ins;
 }
-void simplify_reshapes::apply(program& p) const
-{
-    auto end = std::prev(p.end());
-    for(auto ins : iterator_for(p))
-    {
-        if(ins == end and ins->name() == "contiguous")
-            continue;
-        // Skip possible dead instructions
-        if(ins->outputs().empty() and ins != end)
-            continue;
-        if(is_reshaper(ins))
-        {
-            if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
-                continue;
-            // Gather reshapes
-            std::vector<instruction_ref> reshapes{ins};
-            while(is_reshaper(reshapes.back()))
-            {
-                assert(!reshapes.back()->inputs().empty());
-                assert(p.has_instruction(reshapes.back()->inputs().front()));
-                auto input = reshapes.back()->inputs().front();
-                reshapes.push_back(input);
-            }
-            std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
-            for(auto start : iterator_for(reshapes))
-            {
-                auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
-                    return i->get_shape() == (*start)->get_shape() and i != (*start);
-                });
-                if(last != reshapes.rend())
-                {
-                    r = std::make_pair(*start, *last);
-                    break;
-                }
-            }
-            if(r.first != r.second)
-            {
-                p.replace_instruction(r.first, r.second);
-            }
-        }
-        else if(ins->name() == "transpose")
-        {
-            if(is_transpose_output(ins))
-                continue;
-            auto x = ins;
-            auto t = ins;
-            do
-            {
-                x = t;
-                t = find_transpose_input(x);
-            } while(x != t and t->name() == "transpose");
-            if(t == ins or t->name() != "transpose")
-                continue;
-            p.replace_instruction(ins, t->inputs().front());
-        }
-    }
-}
+auto get_transpose_dims(instruction_ref ins)
+{
+    return any_cast<const op::transpose&>(ins->get_operator()).dims;
+}
+
+std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
+{
+    std::vector<int64_t> result(dims.size());
+    assert(dims.size() == permutation.size());
+    for(std::size_t i = 0; i < dims.size(); i++)
+    {
+        result[i] = dims[permutation[i]];
+    }
+    return result;
+}
+
+bool is_no_transpose(const std::vector<int64_t>& dims)
+{
+    if(dims.empty())
+        return true;
+    if(dims.front() != 0)
+        return false;
+    return std::adjacent_find(
+               dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
+}
+
+template <class Vector, class Op>
+std::vector<int64_t> sort_permutation(const Vector& data, Op op)
+{
+    std::vector<std::int64_t> result(data.size());
+    std::iota(result.begin(), result.end(), 0);
+    std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
+    return result;
+}
+
+std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
+{
+    return sort_permutation(permutation, std::less<>{});
+}
+
+std::vector<int64_t> find_permutation(const shape& s)
+{
+    return sort_permutation(s.strides(), std::greater<>{});
+}
+
+struct find_reshaper
+{
+    auto matcher() const
+    {
+        return match::name(reshaper_names())(
+            match::any_of[match::outputs()](match::name(reshaper_names())));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        std::vector<instruction_ref> reshapes{ins};
+        while(is_reshaper(reshapes.back()))
+        {
+            assert(!reshapes.back()->inputs().empty());
+            assert(p.has_instruction(reshapes.back()->inputs().front()));
+            auto input = reshapes.back()->inputs().front();
+            reshapes.push_back(input);
+        }
+
+        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
+        for(auto start : iterator_for(reshapes))
+        {
+            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
+                return i->get_shape() == (*start)->get_shape() and i != (*start);
+            });
+            if(last != reshapes.rend())
+            {
+                r = std::make_pair(*start, *last);
+                break;
+            }
+        }
+        if(r.first != r.second)
+        {
+            p.replace_instruction(r.first, r.second);
+        }
+    }
+};
+
+struct find_nop_reshapes
+{
+    auto matcher() const
+    {
+        auto reshapes = reshaper_names();
+        reshapes.insert("transpose");
+        reshapes.insert("slice");
+        return match::name(reshapes)(match::same_shape(match::arg(0)));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        p.replace_instruction(ins, ins->inputs().front());
+    }
+};
+
+struct find_transpose
+{
+    auto matcher() const
+    {
+        return match::name("transpose")(match::none_of(
+            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto x   = ins;
+        auto t   = ins;
+        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
+        std::iota(dims.begin(), dims.end(), 0);
+        do
+        {
+            dims = reorder_dims(get_transpose_dims(t), dims);
+            x    = t;
+            t    = find_transpose_input(x);
+        } while(x != t and t->name() == "transpose");
+        if(t == ins or t->name() != "transpose")
+            return;
+        if(is_no_transpose(dims))
+        {
+            p.replace_instruction(ins, t->inputs().front());
+        }
+        else
+        {
+            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
+        }
+    }
+};
+
+struct find_concat_transpose
+{
+    auto matcher() const
+    {
+        return match::name("concat")(match::same_input_shapes(),
+                                     match::all_of[match::inputs()](match::transpose_shape()));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto s   = ins->inputs().front()->get_shape();
+        assert(s.transposed());
+        auto op           = any_cast<op::concat>(ins->get_operator());
+        auto permutation  = find_permutation(s);
+        auto ipermutation = invert_permutation(permutation);
+        op.axis           = ipermutation[op.axis];
+
+        std::vector<instruction_ref> inputs;
+        std::transform(
+            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
+                if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
+                    return i->inputs().front();
+                return p.insert_instruction(ins, op::transpose{permutation}, i);
+            });
+        auto concat = p.insert_instruction(ins, op, inputs);
+        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
+        assert(ins->get_shape().lens() == t->get_shape().lens());
+        p.replace_instruction(ins, t);
+    }
+};
+
+void simplify_reshapes::apply(program& p) const
+{
+    auto end = std::prev(p.end());
+    for(auto ins : iterator_for(p))
+    {
+        if(ins == end and ins->name() == "contiguous")
+            continue;
+        // Skip possible dead instructions
+        if(ins->outputs().empty() and ins != end)
+            continue;
+        match::find_matches(p,
+                            ins,
+                            find_nop_reshapes{},
+                            find_reshaper{},
+                            find_transpose{},
+                            find_concat_transpose{});
+    }
+}
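find_transpose folds chains of transposes by composing their permutations through reorder_dims, walking from the outermost transpose inward. A hypothetical composition check (values not from this commit):

// start with identity dims {0, 1, 2}
// outer transpose {2, 0, 1}: dims = reorder_dims({2, 0, 1}, {0, 1, 2}) = {2, 0, 1}
// inner transpose {1, 2, 0}: dims = reorder_dims({1, 2, 0}, {2, 0, 1}) = {0, 1, 2}
// -> is_no_transpose(dims) is true, so the whole chain collapses to its input.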
......
@@ -2,7 +2,19 @@
 #include <migraphx/cpu/lowering.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/dfor.hpp>
-#include <migraphx/operators.hpp>
+#include <migraphx/op/batch_norm.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/elu.hpp>
+#include <migraphx/op/im2col.hpp>
+#include <migraphx/op/leaky_relu.hpp>
+#include <migraphx/op/logsoftmax.hpp>
+#include <migraphx/op/lrn.hpp>
+#include <migraphx/op/pad.hpp>
+#include <migraphx/op/pooling.hpp>
+#include <migraphx/op/softmax.hpp>
+#include <migraphx/op/argmax.hpp>
+#include <migraphx/op/argmin.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/par_dfor.hpp>
@@ -650,18 +662,11 @@ struct cpu_softmax
     std::string name() const { return "cpu::softmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

-    template <typename T>
-    std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
-    {
-        idx[axis] = 0;
-        return batch_shape.index(idx);
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto batch_lens        = output_shape.lens();
+        std::size_t n_dims     = batch_lens[op.axis];
         batch_lens[op.axis]    = 1;
         shape batch_shape{shape::int32_type, batch_lens};
@@ -669,26 +674,33 @@ struct cpu_softmax
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) =
-                    std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += output(idx.begin(), idx.end());
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) /= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = std::exp(input[index] - batch_max[i]);
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += output(idx.begin(), idx.end());
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) /= batch_sum[i];
+                }
+            });
         });
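Both the old and new implementations subtract the per-batch maximum before exponentiating, the usual trick that keeps exp from overflowing:

\[ \mathrm{softmax}(x)_k = \frac{e^{x_k - \max_j x_j}}{\sum_m e^{x_m - \max_j x_j}} \]

The rewrite replaces four full-tensor shape_for_each sweeps with a single par_for over batches, each thread walking only its own slice along op.axis.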
@@ -708,49 +720,50 @@ struct cpu_logsoftmax
     std::string name() const { return "cpu::logsoftmax"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

-    template <typename T>
-    std::size_t compute_batch_index(T idx, const shape& batch_shape, int axis) const
-    {
-        idx[axis] = 0;
-        return batch_shape.index(idx);
-    }
-
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
         auto batch_lens        = output_shape.lens();
+        std::size_t n_dims     = batch_lens[op.axis];
         batch_lens[op.axis]    = 1;
         shape batch_shape{shape::int32_type, batch_lens};

+        // use a parallel implementation to achieve better performance,
+        // one thread for one batch
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
                                               std::numeric_limits<value_type>::lowest());
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index       = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
-            });
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
-            });
-
-            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
-            });
-
-            for(std::size_t i = 0; i < batch_sum.size(); ++i)
-            {
-                batch_sum[i] = std::log(batch_sum[i]);
-            }
-
-            shape_for_each(output_shape, [&](auto idx) {
-                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
-                output(idx.begin(), idx.end()) -= batch_sum[index];
-            });
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            par_for(batch_shape.elements(), [&](auto i) {
+                auto idx = batch_shape.multi(i);
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_max[i] = std::max(batch_max[i], input(idx.begin(), idx.end()));
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis]      = j;
+                    std::size_t index = output_shape.index(idx);
+                    output[index]     = input[index] - batch_max[i];
+                }
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    batch_sum[i] += std::exp(output(idx.begin(), idx.end()));
+                }
+
+                batch_sum[i] = std::log(batch_sum[i]);
+
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    idx[op.axis] = j;
+                    output(idx.begin(), idx.end()) -= batch_sum[i];
+                }
+            });
         });
......
@@ -12,9 +12,12 @@ endif()
 add_library(migraphx_device
     device/add.cpp
+    device/argmax.cpp
+    device/argmin.cpp
     device/max.cpp
     device/min.cpp
     device/exp.cpp
+    device/erf.cpp
     device/log.cpp
     device/sin.cpp
     device/cos.cpp
@@ -36,6 +39,7 @@ add_library(migraphx_device
     device/sub.cpp
    device/pack.cpp
     device/clip.cpp
+    device/reduce_sum.cpp
 )
 set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
 rocm_clang_tidy_check(migraphx_device)
@@ -44,6 +48,8 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
 target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)

 add_library(migraphx_gpu
+    argmax.cpp
+    argmin.cpp
     eliminate_workspace.cpp
     fuse_ops.cpp
     hip.cpp
@@ -74,6 +80,7 @@ add_library(migraphx_gpu
     schedule_model.cpp
     adjust_allocation.cpp
     clip.cpp
+    reduce_sum.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 rocm_clang_tidy_check(migraphx_gpu)
......