Commit 511c8d8f authored by Paul

Merge from develop

parents 9b7c44ab 2a2c146c
@@ -33,6 +33,10 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
     return std::find(c.begin(), c.end(), x);
 }
 
+struct empty
+{
+};
+
 } // namespace detail
 
 template <class C, class T>
@@ -71,6 +75,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::all_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool all_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class C, class Predicate>
 bool any_of(const C& c, const Predicate& p)
 {
@@ -83,6 +93,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::any_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool any_of(detail::empty, const Predicate&)
+{
+    return false;
+}
+
 template <class C, class Predicate>
 bool none_of(const C& c, const Predicate& p)
 {
@@ -95,6 +111,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
     return std::none_of(c.begin(), c.end(), p);
 }
 
+template <class Predicate>
+bool none_of(detail::empty, const Predicate&)
+{
+    return true;
+}
+
 template <class Range, class Iterator>
 void copy(Range&& r, Iterator it)
 {
...
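Aside (not part of the commit): the detail::empty overloads above give callers a vacuous-truth base case when there is no real range to inspect. A minimal standalone sketch of the idiom, using a local all_of rather than the migraphx one:

#include <algorithm>
#include <cassert>
#include <vector>

namespace detail {
struct empty
{
};
} // namespace detail

// Overload used when there is no container to inspect: vacuously true.
template <class Predicate>
bool all_of(detail::empty, const Predicate&)
{
    return true;
}

template <class C, class Predicate>
bool all_of(const C& c, const Predicate& p)
{
    return std::all_of(c.begin(), c.end(), p);
}

int main()
{
    std::vector<int> v{2, 4, 6};
    auto is_even = [](int x) { return x % 2 == 0; };
    assert(all_of(v, is_even));
    // With no elements at all, the empty tag keeps the call well-formed;
    // partial ordering picks the detail::empty overload.
    assert(all_of(detail::empty{}, is_even));
}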
@@ -212,6 +212,25 @@ auto visit_all(T&& x, Ts&&... xs)
     };
 }
 
+template <class T>
+auto visit_all(const std::vector<T>& x)
+{
+    auto&& s = x.front().get_shape();
+    if(!std::all_of(
+           x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
+        MIGRAPHX_THROW("Types must be the same");
+    return [&](auto v) {
+        s.visit_type([&](auto as) {
+            using type = typename decltype(as)::type;
+            std::vector<tensor_view<type>> result;
+            std::transform(x.begin(), x.end(), std::back_inserter(result), [&](const auto& y) {
+                return make_view(y.get_shape(), as.from(y.data()));
+            });
+            v(result);
+        });
+    };
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
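Aside (not part of the commit): the new visit_all overload checks that every element of the vector shares one runtime type and then invokes the visitor once with a vector of typed views. A standalone stand-in for the pattern (the erased struct below is hypothetical, not MIGraphX's argument type):

#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>

// Minimal stand-in: every element must carry the same runtime type;
// the visitor is invoked once with a vector of typed views.
struct erased
{
    int type_id;             // stand-in for get_shape().type()
    std::vector<float> data; // pretend storage
};

template <class V>
void visit_all(const std::vector<erased>& xs, V v)
{
    assert(!xs.empty());
    for(const auto& x : xs)
        assert(x.type_id == xs.front().type_id && "Types must be the same");
    std::vector<const std::vector<float>*> views;
    for(const auto& x : xs)
        views.push_back(&x.data);
    v(views); // one call with all typed views, mirroring the new overload
}

int main()
{
    std::vector<erased> args{{0, {1, 2}}, {0, {3, 4}}};
    visit_all(args, [](const auto& views) {
        float sum = 0;
        for(const auto* v : views)
            sum = std::accumulate(v->begin(), v->end(), sum);
        std::cout << sum << "\n"; // prints 10
    });
}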
@@ -15,35 +15,18 @@ struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
 template <bool B>
 using bool_c = std::integral_constant<bool, B>;
 
-template <int N>
-struct requires_enum
-{
-    enum e
-    {
-        a = 0
-    };
-};
-
-#define MIGRAPHX_REQUIRES_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y) x##y
+#define MIGRAPHX_REQUIRES_CAT(x, y) MIGRAPHX_REQUIRES_PRIMITIVE_CAT(x, y)
+
+#define MIGRAPHX_REQUIRES_VAR() MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__)
 
 #ifdef CPPCHECK
 #define MIGRAPHX_REQUIRES(...) class = void
 #else
-#if 0
-// TODO: This currently crashed on clang
-#define MIGRAPHX_REQUIRES(...) \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
-        PrivateRequires, \
-        __LINE__) = migraphx::requires_enum<__LINE__>::a, \
-    class = typename std::enable_if<and_<__VA_ARGS__, \
-                                         MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
-                                             migraphx::requires_enum<__LINE__>::a>{}>::type
-#else
-#define MIGRAPHX_REQUIRES(...) \
-    typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
-        PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
-    class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
-#endif
+#define MIGRAPHX_REQUIRES(...) \
+    bool MIGRAPHX_REQUIRES_VAR() = true, \
+    typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \
+                            int>::type = 0
 #endif
 
 } // namespace MIGRAPHX_INLINE_NS
...
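Aside (not part of the commit): the reworked MIGRAPHX_REQUIRES expands to a defaulted bool non-type parameter named after __LINE__ plus an enable_if parameter that references it, which keeps constraints on different lines from colliding and drops the old enum-based workaround. A self-contained sketch of the same idiom with a plain condition in place of migraphx::and_:

#include <type_traits>

// A defaulted non-type template parameter keeps each use of the constraint
// unique, so declarations on different lines do not collide.
#define REQUIRES_CAT_(x, y) x##y
#define REQUIRES_CAT(x, y) REQUIRES_CAT_(x, y)
#define REQUIRES_VAR() REQUIRES_CAT(PrivateRequires, __LINE__)

#define REQUIRES(...)                                                            \
    bool REQUIRES_VAR() = true,                                                  \
         typename std::enable_if<(REQUIRES_VAR() && (__VA_ARGS__)), int>::type = 0

template <class T, REQUIRES(std::is_integral<T>{})>
int twice(T x)
{
    return 2 * x;
}

int main() { return twice(21) == 42 ? 0 : 1; }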
@@ -42,7 +42,9 @@ template <class Range>
 auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
     -> decltype(r.begin(), r.end(), void())
 {
+    os << "{";
     os << stream_range(r);
+    os << "}";
 }
 
 template <class T>
...
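Aside (not part of the commit): the effect of the two added lines is that a printed range now reads as {1, 2, 3} instead of bare elements. A standalone sketch; the element separator here is an assumption, since in MIGraphX it comes from stream_range:

#include <iostream>
#include <vector>

// Standalone sketch of the change: brace-delimit a printed range.
template <class Range>
void write_range(std::ostream& os, const Range& r)
{
    os << "{";
    const char* sep = "";
    for(const auto& x : r)
    {
        os << sep << x;
        sep = ", ";
    }
    os << "}";
}

int main()
{
    std::vector<int> v{1, 2, 3};
    write_range(std::cout, v); // prints {1, 2, 3}
    std::cout << "\n";
}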
@@ -52,6 +52,8 @@ inline std::string transform_string(std::string s, F f)
 inline std::string to_upper(std::string s) { return transform_string(std::move(s), ::toupper); }
 
+inline std::string to_lower(std::string s) { return transform_string(std::move(s), ::tolower); }
+
 inline bool starts_with(const std::string& value, const std::string& prefix)
 {
     if(prefix.size() > value.size())
...
@@ -19,7 +19,7 @@ rocm_install_targets(
 add_executable(read_onnx read_onnx.cpp)
 rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraphx_onnx)
+target_link_libraries(read_onnx migraphx_cpu migraphx_onnx)
 
 if(MIGRAPHX_ENABLE_GPU)
...
@@ -100,6 +100,7 @@ struct onnx_parser
     void init_actv_func()
     {
+        // Support names written in all lower case or with the first letter capitalized
         map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
         map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
         map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
@@ -352,7 +353,8 @@ struct onnx_parser
         {
             // insert zeros for pad op (args[0] has 4 dims)
             padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
-            l0 = prog.add_instruction(op::pad{padding}, l0);
+            l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
+                                      l0);
         }
         else
         {
@@ -870,7 +872,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
 
         auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
@@ -961,7 +965,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
 
         // need 4 activation functions
@@ -1088,7 +1094,9 @@ struct onnx_parser
             auto names = attributes.at("activations").strings();
             vec_names.clear();
             vec_names.resize(names.size());
-            std::copy(names.begin(), names.end(), vec_names.begin());
+            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
+                return to_lower(name);
+            });
         }
 
         // need 6 activation functions for bidirectional directions
...
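Aside (not part of the commit): the parser changes above lower-case the activation names before lookup, so "Tanh", "RELU" and "sigmoid" all resolve against the same table. A standalone sketch of that lookup path:

#include <algorithm>
#include <cassert>
#include <cctype>
#include <map>
#include <string>
#include <vector>

std::string to_lower(std::string s)
{
    std::transform(s.begin(), s.end(), s.begin(), ::tolower);
    return s;
}

int main()
{
    // Stand-in for map_actv_funcs: keys are stored lower case.
    std::map<std::string, int> actv_funcs{{"tanh", 0}, {"relu", 1}, {"sigmoid", 2}};
    std::vector<std::string> names{"Tanh", "RELU"};

    std::vector<std::string> vec_names(names.size());
    std::transform(names.begin(), names.end(), vec_names.begin(),
                   [](const std::string& name) { return to_lower(name); });

    for(const auto& n : vec_names)
        assert(actv_funcs.count(n) == 1); // case-insensitive match succeeds
}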
@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
     C_VISIBILITY_PRESET hidden
     CXX_VISIBILITY_PRESET hidden
   )
-  if(MIGRAPHX_ENABLE_TF)
-    target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
-    target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
-  else()
-    target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
-  endif()
+  target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
   if(MIGRAPHX_ENABLE_GPU)
     target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
     target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
...
@@ -6,11 +6,9 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/stringutils.hpp>
-#ifdef ENABLE_TF
 #include <migraphx/tf.hpp>
-#else
 #include <migraphx/onnx.hpp>
-#endif
+#include <migraphx/type_name.hpp>
 
 #ifdef HAVE_GPU
 #include <migraphx/gpu/target.hpp>
@@ -104,8 +102,13 @@ migraphx::shape to_shape(const py::buffer_info& info)
             t = as.type_enum();
             n = sizeof(as());
         }
     });
+
+    if(n == 0)
+    {
+        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type: " + info.format);
+    }
+
     auto strides = info.strides;
     std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
         return n > 0 ? i / n : 0;
@@ -161,16 +164,13 @@ PYBIND11_MODULE(migraphx, m)
         .def("__ne__", std::not_equal_to<migraphx::program>{})
         .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
 
-#ifdef ENABLE_TF
     m.def("parse_tf",
           &migraphx::parse_tf,
           "Parse tf protobuf (default format is nhwc)",
          py::arg("filename"),
          py::arg("is_nhwc") = true);
-#else
+
     m.def("parse_onnx", &migraphx::parse_onnx);
-#endif
 
     m.def("get_target", [](const std::string& name) -> migraphx::target {
         if(name == "cpu")
             return migraphx::cpu::target{};
...
@@ -205,16 +205,18 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    auto sih_lens = sih->get_shape().lens();
 
     // bias
+    instruction_ref bb{};
     if(bias != prog.end())
     {
-        long hs = r->get_shape().lens()[2];
+        long hs = static_cast<long>(r->get_shape().lens()[2]);
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
         auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
         auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
-        bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape().lens()}, b);
+        auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb);
+        bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
     }
 
     instruction_ref hidden_out = prog.end();
@@ -228,19 +230,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
         auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
         auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
-        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        instruction_ref ht;
         if(bias != prog.end())
         {
-            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
-        }
-        else
-        {
-            ht = xt_ht;
+            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
         }
+        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
 
         // apply activation function
-        ht = prog.insert_instruction(ins, actv_func, ht);
+        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
         sih = ht;
 
         // add the dimensions of sequence length (axis 0 for sequence length,
@@ -485,62 +482,41 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     long hs = static_cast<long>(r_shape.lens()[2]);
     migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
-    std::vector<int> data(s.elements(), 1);
+    std::vector<float> data(s.elements(), 1.0f);
     auto l1 = prog.add_literal(migraphx::literal{s, data});
 
-    // weight matrix
+    // squeeze the w matrix to 2 dims and transpose it
     std::vector<int64_t> perm{1, 0};
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);
-
-    auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
-
-    auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);
+    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
 
+    // slice r into two parts: zr and h
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);
-
-    auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);
-
+    auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
+    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
     auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);
+    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
 
     // initial states
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
+    size_t bs = ih->get_shape().lens()[1];
 
     // bias
-    instruction_ref brcst_bz{};
-    instruction_ref brcst_br{};
-    instruction_ref brcst_wbh{};
-    instruction_ref brcst_rbh{};
-    instruction_ref brcst_bh{};
+    instruction_ref bwb{};
+    instruction_ref brb_zr{};
+    instruction_ref brb_h{};
     if(bias != prog.end())
     {
-        auto broadcast_lens = sih->get_shape().lens();
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, wbh);
-
-        auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, rbh);
-
-        auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
-        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bz);
-
-        auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
-        brcst_br = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, br);
-
-        auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
-        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, broadcast_lens}, bh);
+        auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
+        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
+
+        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
+        auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
+        brb_zr = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
+        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
     }
 
     for(long i = 0; i < seq_len; i++)
@@ -549,56 +525,58 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
 
-        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
-        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
-        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
-        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
+        auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
+        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
         if(bias != prog.end())
         {
-            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
+            xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
+            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
         }
-        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
 
-        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
-        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
-        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
-        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
-        if(bias != prog.end())
-        {
-            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
-        }
-        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
+        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
+        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
+        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
+
+        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
+        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
+
+        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
+        auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
+
+        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
+        auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
 
-        instruction_ref xht_h;
+        instruction_ref hr_h{};
         if(linear_before_reset == 0)
         {
             // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
             auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
-            auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
-            xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
             if(bias != prog.end())
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
+            }
+            else
+            {
+                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
             }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
-            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
+            instruction_ref ht1_rh{};
             if(bias != prog.end())
             {
-                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
             }
-            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
-            xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
-            if(bias != prog.end())
+            else
             {
-                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
+                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
             }
+            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
         }
-        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
+
+        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
+        auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
 
         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
         auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
@@ -913,35 +891,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     migraphx::shape r_shape = r->get_shape();
     long seq_len = static_cast<long>(seq_shape.lens()[0]);
     long hs = static_cast<long>(r_shape.lens()[2]);
+    auto bs = ih->get_shape().lens()[1];
 
     std::vector<int64_t> perm{1, 0};
-    // w matrix
+    // w matrix, squeeze and transpose
     auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
-    auto wi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
-    auto tran_wi = prog.insert_instruction(ins, op::transpose{perm}, wi);
-
-    auto wo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
-    auto tran_wo = prog.insert_instruction(ins, op::transpose{perm}, wo);
-
-    auto wf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
-    auto tran_wf = prog.insert_instruction(ins, op::transpose{perm}, wf);
-
-    auto wc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sw);
-    auto tran_wc = prog.insert_instruction(ins, op::transpose{perm}, wc);
+    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
 
-    // r matrix
+    // r matrix, squeeze and transpose
     auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
-    auto ri = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
-    auto tran_ri = prog.insert_instruction(ins, op::transpose{perm}, ri);
-
-    auto ro = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
-    auto tran_ro = prog.insert_instruction(ins, op::transpose{perm}, ro);
-
-    auto rf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
-    auto tran_rf = prog.insert_instruction(ins, op::transpose{perm}, rf);
-
-    auto rc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sr);
-    auto tran_rc = prog.insert_instruction(ins, op::transpose{perm}, rc);
+    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
 
     // initial hidden state
     auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
@@ -951,40 +910,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto ic_lens = sic->get_shape().lens();
 
     // bias
-    instruction_ref bi_brcst{};
-    instruction_ref bo_brcst{};
-    instruction_ref bf_brcst{};
-    instruction_ref bc_brcst{};
+    instruction_ref wrb{};
     if(bias != prog.end())
     {
         auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
-        auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
-        auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
-        auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
-        bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bi);
-
-        auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
-        auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
-        auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
-        bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bo);
-
-        auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
-        auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
-        auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
-        bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bf);
-
-        auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
-        auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
-        auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
-        bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, bc);
+        auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
+        auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
+        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
+
+        wrb = prog.insert_instruction(
+            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
     }
 
     // peep hole
     instruction_ref pphi_brcst{};
     instruction_ref ppho_brcst{};
     instruction_ref pphf_brcst{};
     if(pph != prog.end())
     {
         auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
@@ -1004,44 +946,31 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
         xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
 
-        // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
-        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
-        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
-        auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
-        if(pph != prog.end())
-        {
-            auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
-        }
+        auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
+        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
+        auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
         if(bias != prog.end())
         {
-            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
+            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
         }
-        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
 
-        // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
-        auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
-        auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
-        auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
+        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
+        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
+        auto ft_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
+        auto ct_before_actv =
+            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
+
         if(pph != prog.end())
         {
+            auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
+            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
             auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
             ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
         }
-        if(bias != prog.end())
-        {
-            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
-        }
+        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
         auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
-
-        // equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
-        auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
-        auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
-        auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
-        if(bias != prog.end())
-        {
-            ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
-        }
         auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
 
         // equation Ct = ft (.) Ct-1 + it (.) ct
@@ -1050,19 +979,11 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
         auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
         last_cell_output = cellt;
 
-        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
-        auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
-        auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
-        auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
         if(pph != prog.end())
         {
             auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
             ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
         }
-        if(bias != prog.end())
-        {
-            ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
-        }
         auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
 
         // Ht = ot (.) h(Ct)
...
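Aside (not part of the commit): the common thread in these RNN/GRU/LSTM rewrites is replacing one GEMM per gate with a single GEMM against the concatenated, transposed weights, then slicing per-gate columns out of the result. A tiny numeric sketch showing the two forms agree:

#include <cassert>
#include <cstddef>
#include <vector>

// Toy row-major matmul: (m x k) * (k x n) -> (m x n).
std::vector<float> matmul(const std::vector<float>& a,
                          const std::vector<float>& b,
                          std::size_t m, std::size_t k, std::size_t n)
{
    std::vector<float> c(m * n, 0.0f);
    for(std::size_t i = 0; i < m; i++)
        for(std::size_t j = 0; j < n; j++)
            for(std::size_t l = 0; l < k; l++)
                c[i * n + j] += a[i * k + l] * b[l * n + j];
    return c;
}

int main()
{
    // 1 row of input, 2 features, 2 gates of hidden size 1.
    std::vector<float> x{1.0f, 2.0f};  // 1 x 2
    std::vector<float> wt{1.0f, 3.0f,  // 2 x 2: column 0 is gate z, column 1 is gate r
                          2.0f, 4.0f};

    // One fused GEMM over the concatenated gate weights...
    auto fused = matmul(x, wt, 1, 2, 2);

    // ...equals per-gate GEMMs on the sliced columns.
    std::vector<float> wz{1.0f, 2.0f}; // column 0 of wt (2 x 1)
    std::vector<float> wr{3.0f, 4.0f}; // column 1 of wt (2 x 1)
    assert(fused[0] == matmul(x, wz, 1, 2, 1)[0]); // gate z: 1*1 + 2*2 = 5
    assert(fused[1] == matmul(x, wr, 1, 2, 1)[0]); // gate r: 1*3 + 2*4 = 11
}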
@@ -517,40 +517,60 @@ struct cpu_unary
     }
 };
 
-struct softmax2d
+struct cpu_softmax
 {
-    std::string name() const { return "cpu::softmax2d"; }
-    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
-    {
-        argument result{output_shape};
-        visit_all(result, args[0])([&](auto output, auto input) {
-            using value_type = typename decltype(input)::value_type;
-            auto nb = input.get_shape().lens()[0];
-            auto nc = input.get_shape().lens()[1];
-            auto nh = input.get_shape().lens()[2];
-            auto nw = input.get_shape().lens()[3];
-            dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
-                value_type cmax = std::numeric_limits<value_type>::lowest();
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    cmax = std::max(cmax, input(b, c, i, j));
-                }
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
-                }
-                value_type sum = value_type(0);
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    sum += output(b, c, i, j);
-                }
-                for(std::size_t c = 0; c < nc; c++)
-                {
-                    output(b, c, i, j) = output(b, c, i, j) / sum;
-                }
-            });
-        });
-        return result;
-    }
-};
+    op::softmax op;
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return migraphx::reflect(self.op, f);
+    }
+
+    std::string name() const { return "cpu::softmax"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+
+    template <typename T>
+    std::size_t compute_batch_index(T idx, shape& batch_shape, int axis) const
+    {
+        idx[axis] = 0;
+        return batch_shape.index(idx);
+    }
+
+    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        auto batch_lens = output_shape.lens();
+        batch_lens[op.axis] = 1;
+        shape batch_shape{shape::int32_type, batch_lens};
+
+        visit_all(result, args[0])([&](auto output, auto input) {
+            using value_type = typename decltype(input)::value_type;
+            std::vector<value_type> batch_max(batch_shape.elements(),
+                                              std::numeric_limits<value_type>::lowest());
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
+            });
+
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                output(idx.begin(), idx.end()) =
+                    std::exp(input(idx.begin(), idx.end()) - batch_max[index]);
+            });
+
+            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                batch_sum[index] += output(idx.begin(), idx.end());
+            });
+
+            shape_for_each(output_shape, [&](auto idx) {
+                auto index = this->compute_batch_index(idx, batch_shape, op.axis);
+                output(idx.begin(), idx.end()) /= batch_sum[index];
+            });
+        });
+        return result;
+    }
+};
@@ -569,33 +589,19 @@ struct cpu_logsoftmax
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
 
     template <typename T>
-    std::size_t compute_batch_index(const T& idx, shape& batch_shape, int axis) const
+    std::size_t compute_batch_index(T idx, const shape& batch_shape, int axis) const
     {
-        if(axis == 0)
-        {
-            return 0;
-        }
-        else
-        {
-            std::vector<std::size_t> batch_idx(idx.begin(), idx.begin() + axis);
-            return batch_shape.index(batch_idx.begin(), batch_idx.end());
-        }
+        idx[axis] = 0;
+        return batch_shape.index(idx);
     }
 
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        auto lens = output_shape.lens();
-        std::vector<std::size_t> batch_lens{};
-        if(op.axis == 0)
-        {
-            batch_lens.push_back(1);
-        }
-        else
-        {
-            batch_lens.insert(batch_lens.begin(), lens.begin(), lens.begin() + op.axis);
-        }
-        shape batch_shape{migraphx::shape::uint32_type, batch_lens};
+        auto batch_lens = output_shape.lens();
+        batch_lens[op.axis] = 1;
+        shape batch_shape{shape::int32_type, batch_lens};
+
         visit_all(result, args[0])([&](auto output, auto input) {
             using value_type = typename decltype(input)::value_type;
             std::vector<value_type> batch_max(batch_shape.elements(),
@@ -660,7 +666,7 @@ struct cpu_apply
         apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
         apply_map["lrn"] = extend_op<cpu_lrn, op::lrn>();
         apply_map["pad"] = extend_op<cpu_pad, op::pad>();
-        apply_map["softmax"] = simple_op<softmax2d>();
+        apply_map["softmax"] = extend_op<cpu_softmax, op::softmax>();
     }
 
     void apply()
...
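Aside (not part of the commit): cpu_softmax generalizes the old fixed-4D softmax2d to any axis by zeroing the softmax axis of each multi-index to get a per-batch bucket, then doing the usual max/exp/sum/divide per bucket. A standalone sketch for a 2 x 3 tensor with axis = 1:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

int main()
{
    const std::size_t rows = 2, cols = 3;
    std::vector<float> in{1, 2, 3, 1, 1, 1};
    std::vector<float> out(in.size());

    // One bucket per row: the multi-index with axis 1 zeroed.
    std::vector<float> bmax(rows, std::numeric_limits<float>::lowest());
    std::vector<float> bsum(rows, 0.0f);

    for(std::size_t r = 0; r < rows; r++)
        for(std::size_t c = 0; c < cols; c++)
            bmax[r] = std::max(bmax[r], in[r * cols + c]);
    for(std::size_t r = 0; r < rows; r++)
        for(std::size_t c = 0; c < cols; c++)
        {
            out[r * cols + c] = std::exp(in[r * cols + c] - bmax[r]);
            bsum[r] += out[r * cols + c];
        }
    for(std::size_t r = 0; r < rows; r++)
        for(std::size_t c = 0; c < cols; c++)
            out[r * cols + c] /= bsum[r];

    for(float v : out)
        std::cout << v << " "; // each row sums to 1
    std::cout << "\n";
}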
@@ -27,6 +27,7 @@ add_library(migraphx_device
   device/add_relu.cpp
   device/contiguous.cpp
   device/logsoftmax.cpp
+  device/softmax.cpp
   device/convert.cpp
   device/mul.cpp
   device/concat.cpp
...
@@ -10,22 +10,20 @@ namespace gpu {
 namespace device {
 
 argument concat(hipStream_t stream,
-                const migraphx::shape& output_shape,
+                const migraphx::shape&,
                 std::vector<migraphx::argument> args,
                 std::vector<std::size_t> offsets)
 {
-    for(std::size_t l = 0; l < args.size() - 1; l++)
+    auto ninputs = args.size() - 1;
+    for(std::size_t j = 0; j < ninputs; j++)
     {
-        auto argl = args[l];
-        std::size_t nelements = argl.get_shape().elements();
-        visit_all(args.back(), argl)([&](auto output, auto input) {
-            visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-                auto* outptr = output.data() + offsets[l];
-                const auto* inptr = input.data();
-                hip_tensor_descriptor<ndim> desc_input(input.get_shape());
-                hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-                gs_launch(stream, nelements)(
-                    [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
+        auto&& arg = args[j];
+        std::size_t nelements = arg.get_shape().elements();
+        auto offset = offsets[j];
+        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
+            gs_launch(stream, nelements)([=](auto i) {
+                auto idx = output.get_shape().index(input.get_shape().multi(i));
+                output.data()[idx + offset] = input.data()[i];
             });
         });
     }
...
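Aside (not part of the commit): the rewritten concat kernel writes each input element to the output at the input's multi-index mapped through the output shape, shifted by a precomputed per-input offset. A host-side sketch of that index math, concatenating two 2 x 2 matrices along axis 1:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<float> a{1, 2, 3, 4}, b{5, 6, 7, 8};
    std::vector<float> out(8, 0);
    const std::size_t out_cols = 4, in_cols = 2;
    const std::size_t offsets[] = {0, 2}; // linear offset of each input along the axis

    const std::vector<float>* inputs[] = {&a, &b};
    for(std::size_t j = 0; j < 2; j++)
        for(std::size_t i = 0; i < inputs[j]->size(); i++)
        {
            // Input multi-index, then the output's linear index plus the offset.
            std::size_t r = i / in_cols, c = i % in_cols;
            out[r * out_cols + c + offsets[j]] = (*inputs[j])[i];
        }

    std::vector<float> expected{1, 2, 5, 6, 3, 4, 7, 8};
    assert(out == expected);
}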
@@ -11,35 +11,30 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-argument gather(hipStream_t stream,
-                const migraphx::shape& output_shape,
-                std::vector<migraphx::argument> args,
-                int axis)
+argument gather(hipStream_t stream, argument result, argument arg1, argument arg2, int axis)
 {
-    auto axis_index = (axis < 0) ? (axis + args[0].get_shape().lens().size()) : axis;
-    visit_all(args.back(), args[0])([&](auto output, auto input) {
-        std::size_t nelements = output_shape.elements();
-        args[1].visit([&](auto indices) {
-            const auto* indices_ptr = device_cast(indices.data());
-            auto* out_ptr = device_cast(output.data());
-            const auto* in_ptr = device_cast(input.data());
-            auto& input_shape = args[0].get_shape();
-            auto lens = input_shape.lens();
-            lens[axis_index] = args[1].get_shape().elements();
-            migraphx::shape out_comp_shape{output_shape.type(), lens};
-            visit_tensor_size(out_comp_shape.lens().size(), [&](auto n_out_dim) {
-                hip_tensor_descriptor<n_out_dim> desc_input(input_shape);
-                hip_tensor_descriptor<n_out_dim> desc_output(out_comp_shape);
-                gs_launch(stream, nelements)([=](auto ii) {
-                    auto in_idx = desc_output.multi(ii);
-                    in_idx[axis_index] = indices_ptr[in_idx[axis_index]];
-                    out_ptr[ii] = in_ptr[desc_input.linear(in_idx)];
-                });
+    auto axis_index = (axis < 0) ? (axis + arg1.get_shape().lens().size()) : axis;
+    auto& input_shape = arg1.get_shape();
+    auto lens = input_shape.lens();
+    lens[axis_index] = arg2.get_shape().elements();
+    shape out_comp_shape{result.get_shape().type(), lens};
+    std::size_t nelements = result.get_shape().elements();
+
+    visit_all(result, arg1)([&](auto output, auto input_v) {
+        hip_visit_views(input_v, out_comp_shape)([&](auto input, auto out_comp) {
+            arg2.visit([&](auto indices) {
+                const auto* indices_ptr = device_cast(indices.data());
+                auto* output_ptr = device_cast(output.data());
+                gs_launch(stream, nelements)([=](auto i) {
+                    auto idx = out_comp.multi(i);
+                    idx[axis_index] = indices_ptr[idx[axis_index]];
+                    output_ptr[i] = input[idx];
                });
             });
         });
     });
-    return args.back();
+    return result;
 }
 
 } // namespace device
...
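Aside (not part of the commit): gather builds a comparison shape whose axis length equals the number of indices, then remaps the axis coordinate of every output multi-index through the indices array. A host-side sketch with a 3 x 2 input, axis 0 and indices {2, 0}:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<float> in{1, 2, 3, 4, 5, 6}; // 3 x 2, row-major
    std::vector<int> indices{2, 0};
    const std::size_t cols = 2, out_rows = indices.size();

    std::vector<float> out(out_rows * cols);
    for(std::size_t i = 0; i < out.size(); i++)
    {
        std::size_t r = i / cols, c = i % cols; // output multi-index
        out[i] = in[indices[r] * cols + c];     // axis coordinate remapped
    }

    std::vector<float> expected{5, 6, 1, 2};
    assert(out == expected);
}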
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARRAY_HPP
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
struct hip_array
{
T d[N];
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<std::size_t, N> size() const { return {}; }
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
{
T result = 0;
for(std::size_t i = 0; i < N; i++)
result += x[i] * d[i];
return result;
}
MIGRAPHX_DEVICE_CONSTEXPR T product() const
{
T result = 1;
for(std::size_t i = 0; i < N; i++)
result *= d[i];
return result;
}
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
{
hip_array result;
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] * y[i];
return result;
}
};
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
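Aside (not part of the commit): hip_array is a fixed-size, device-friendly array whose dot() and product() support shape arithmetic: dot(multi_index, strides) is a linear offset and product(lens) is an element count. A host-side sketch of that use, with a simplified local stand-in:

#include <cassert>
#include <cstddef>

template <class T, std::size_t N>
struct small_array
{
    T d[N];
    T dot(const small_array& x) const
    {
        T r = 0;
        for(std::size_t i = 0; i < N; i++)
            r += x.d[i] * d[i];
        return r;
    }
    T product() const
    {
        T r = 1;
        for(std::size_t i = 0; i < N; i++)
            r *= d[i];
        return r;
    }
};

int main()
{
    small_array<std::size_t, 3> lens{{2, 3, 4}};
    small_array<std::size_t, 3> strides{{12, 4, 1}}; // row-major strides of lens
    small_array<std::size_t, 3> idx{{1, 2, 3}};
    assert(lens.product() == 24);
    assert(idx.dot(strides) == 12 + 8 + 3); // linear offset 23
}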
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP #define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/visit.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/array.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -13,57 +13,30 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -13,57 +13,30 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
template <class T>
using vec4 = T __attribute__((ext_vector_type(4)));
template <class T>
__device__ __host__ vec4<T>* as_vec4(T* x)
{
return reinterpret_cast<vec4<T>*>(x);
}
template <class T>
__device__ __host__ T* as_pointer(vec4<T>* x)
{
return reinterpret_cast<T*>(x);
}
template <class... Ts> template <class... Ts>
auto pack_vec4(Ts... xs) auto pack(Ts... xs) __device__
{ {
return [=](auto f, std::size_t n) { return f(as_vec4(xs)[n]...); }; return [=](auto f) { return f(xs...); };
} }
template <class F, class... Arguments> template <class F, class... Arguments>
auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args) auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args)
{ {
const auto& output_shape = result.get_shape(); std::size_t nelements = result.get_shape().elements();
visit_all(result, args...)([&](auto output, auto... inputs) { hip_visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { gs_launch(stream, nelements)([=](auto i) {
auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, auto idx = output.get_shape().multi(i);
device_cast(inputs.data()))...); output[i] = f(inputs[idx]...);
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = device_cast(output.data());
gs_launch(stream, output_shape.elements())([=](auto i) {
data([&](auto&&... ps) {
auto outidx = out_desc.multi(i);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
});
});
}); });
}); });
} }
template <class F> template <class F, class... Arguments>
void trinary_broadcast_vec_impl(hipStream_t stream, void nary_broadcast_vec_impl(
F f, hipStream_t stream, F f, argument result, argument barg, Arguments... args)
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape(); const auto& b_shape = barg.get_shape();
auto bdim = auto bdim =
std::distance(b_shape.strides().begin(), std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
...@@ -73,156 +46,45 @@ void trinary_broadcast_vec_impl(hipStream_t stream, ...@@ -73,156 +46,45 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = as_vec4(device_cast(input1.data()));
auto* yp = as_vec4(device_cast(input2.data()));
auto* zp = as_vec4(device_cast(input3.data()));
auto* outp = as_vec4(device_cast(output.data()));
const std::size_t vec_size = 4; const std::size_t vec_size = 4;
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size; const std::size_t bdim_vec_len = bdim_len / vec_size;
hip_vec_visit_all<vec_size>(result, barg, args...)(
[&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
const std::size_t nelements = output.size() / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ { launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = zp[i];
}
__syncthreads();
auto* bp = as_pointer(buffer);
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> y = yp[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], y[j], b);
}
outp[i] = out;
}
});
});
}
template <class F>
void trinary_broadcast_impl(hipStream_t stream,
F f,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) { MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = device_cast(input1.data());
auto* yp = device_cast(input2.data());
auto* zp = device_cast(input3.data());
auto* outp = device_cast(output.data());
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size();
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = zp[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
type x = xp[i];
type y = yp[i];
outp[i] = f(x, y, b);
}
});
});
}
template <class F>
void binary_broadcast_vec_impl(
hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = as_vec4(device_cast(input1.data()));
auto* yp = as_vec4(device_cast(input2.data()));
auto* outp = as_vec4(device_cast(output.data()));
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
// Load bias into LDS // Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal) for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{ {
buffer[i] = yp[i]; buffer[i] = binput.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx]; auto b = bp[bidx];
vec4<type> x = xp[i]; auto out = output.data()[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++) for(std::size_t j = 0; j < vec_size; j++)
{ {
out[j] = f(x[j], b); out[j] = f(inputs.data()[i][j]..., b);
} }
outp[i] = out; output.data()[i] = out;
} }
}); });
}); });
} }
template <class F> template <class F, class... Arguments>
void binary_broadcast_impl( void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = barg.get_shape();
auto bdim = auto bdim =
std::distance(b_shape.strides().begin(), std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) { std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
...@@ -232,31 +94,25 @@ void binary_broadcast_impl( ...@@ -232,31 +94,25 @@ void binary_broadcast_impl(
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = device_cast(input1.data());
auto* yp = device_cast(input2.data());
auto* outp = device_cast(output.data());
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal; const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size(); std::size_t nelements = result.get_shape().elements();
hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
using type = typename decltype(output)::value_type;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ { launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED type buffer[2048]; MIGRAPHX_DEVICE_SHARED type buffer[2048];
// Load bias into LDS // Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal) for(size_t i = idx.local; i < bdim_len; i += nlocal)
{ {
buffer[i] = yp[i]; buffer[i] = binput.data()[i];
} }
__syncthreads(); __syncthreads();
// Process the data // Process the data
for(size_t i = idx.global; i < n; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
auto bidx = (i % bdim_next_stride) / bdim_stride; auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx]; auto b = buffer[bidx];
type x = xp[i]; output.data()[i] = f(inputs.data()[i]..., b);
outp[i] = f(x, b);
} }
}); });
}); });
...@@ -265,15 +121,14 @@ void binary_broadcast_impl( ...@@ -265,15 +121,14 @@ void binary_broadcast_impl(
template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
    const auto& output_shape = result.get_shape();
    visit_all(result, args...)([&](auto output, auto... inputs) {
        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
        const std::size_t vec_size = 4;
        auto data  = pack_vec<4>(device_cast(inputs.data())...);
        auto* outp = as_vec<4>(device_cast(output.data()));
        gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
            vec<type, 4> out = outp[i];
            data(
                [&](auto... xs) {
                    for(std::size_t j = 0; j < vec_size; j++)
...@@ -290,13 +145,9 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.

template <class F, class... Arguments>
void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
    std::size_t nelements = result.get_shape().elements();
    hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
        gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
    });
}
...@@ -313,12 +164,6 @@ void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
    nary_nonstandard_impl(stream, f, result, args...);
}

template <class... Arguments>
auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
{
...@@ -332,71 +177,50 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}
template <class... Arguments>
auto nary(hipStream_t stream, argument result)
{
    return [=](auto f) { nary_standard_impl(stream, f, result); };
}

template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args)
{
    return [=](auto f) {
        auto barg     = back_args(args...);
        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
            auto bshape = barg.get_shape();
            const bool standard =
                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
            const bool same_shapes = all_of(
                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
            // TODO: Check result and args shape is the same
            if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
            {
                auto not_zero       = [](auto x) { return x != 0; };
                const auto& strides = bshape.strides();
                auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
                auto b_idx          = std::distance(strides.begin(), b_it);
                auto b_len          = result.get_shape().lens()[b_idx];
                auto b_stride       = result.get_shape().strides()[b_idx];
                assert(bshape.lens()[b_idx] == b_len);
                if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
                {
                    const bool divisible_by_4 =
                        (b_len % 4 == 0) and (b_stride % 4 == 0) and
                        (front_args(args...).get_shape().elements() % 4 == 0);
                    if(divisible_by_4)
                        nary_broadcast_vec_impl(stream, f, result, barg, args2...);
                    else
                        nary_broadcast_impl(stream, f, result, barg, args2...);
                    return false;
                }
            }
            return true;
        });
        if(fallback)
            nary_impl(stream, f, result, args...);
    };
}
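// Usage sketch (illustrative): an elementwise op passes its kernel body to the
// dispatcher returned by nary; the broadcast fast path above is taken only when
// the last argument is a suitable non-scalar broadcast. The function below is an
// assumed caller, not part of this commit.
//
//     void add(hipStream_t stream, const argument& result,
//              const argument& arg1, const argument& arg2)
//     {
//         nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
//     }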
...
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SHAPE_HPP
#include <migraphx/gpu/device/array.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <std::size_t N>
struct hip_shape
{
using hip_index = hip_array<std::size_t, N>;
hip_array<std::size_t, N> lens = {};
hip_array<std::size_t, N> strides = {};
bool standard = false;
__device__ __host__ hip_shape() = default;
hip_shape(const shape& s) : standard(s.standard())
{
assert(s.lens().size() == N);
assert(s.strides().size() == N);
std::copy(s.lens().begin(), s.lens().end(), lens.begin());
std::copy(s.strides().begin(), s.strides().end(), strides.begin());
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const { return lens.product(); }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const { return x.dot(strides); }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
{
std::size_t idx = 0;
for(std::size_t i = 0; i < x.size(); i++)
idx += *(x.begin() + i) * strides[i];
return idx;
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::size_t i) const
{
if(this->standard)
return i;
else
{
const std::size_t rank = this->lens.size();
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens.size(); j++)
{
const std::size_t k = rank - j - 1;
const std::size_t stride = this->strides[k];
const std::size_t len = this->lens[k];
const std::size_t slen = s * len;
const std::size_t idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
}
}
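    // Decomposes a linear index by successive division by the strides; the result
    // is only meaningful when the strides are standard (packed, row-major).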
MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
{
hip_index result;
std::size_t tidx = idx;
for(std::size_t is = 0; is < result.size(); is++)
{
result[is] = tidx / strides[is];
tidx = tidx % strides[is];
}
return result;
}
};
template <std::size_t N>
hip_shape<N> make_hip_shape(const shape& x)
{
return x;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
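// Worked example (illustrative): for a standard 2x3 shape, lens = {2, 3} and
// strides = {3, 1}. Then multi(4) = {4 / 3, (4 % 3) / 1} = {1, 1}, and
// index({1, 1}) = 1 * 3 + 1 * 1 = 4 recovers the linear offset.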
#ifndef MIGRAPHX_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEAVICE_TENSOR_HPP

#include <migraphx/gpu/device/visit.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <std::size_t NDim>
using hip_tensor_index = hip_array<std::size_t, NDim>;

template <std::size_t NDim>
struct hip_tensor_descriptor
{
    __device__ __host__ hip_tensor_descriptor() = default;
...@@ -63,26 +22,26 @@ struct hip_tensor_descriptor
        std::copy(s.strides().begin(), s.strides().end(), strides);
    }

    __device__ __host__ hip_tensor_index<NDim> multi(std::size_t idx) const
    {
        hip_tensor_index<NDim> result{};
        std::size_t tidx = idx;
        for(std::size_t is = 0; is < NDim; is++)
        {
            result[is] = tidx / strides[is];
            tidx       = tidx % strides[is];
        }
        return result;
    }

    __device__ __host__ std::size_t linear(hip_tensor_index<NDim> s) const
    {
        std::size_t idx = 0;
        for(std::size_t i = 0; i < NDim; i++)
            idx += s[i] * strides[i];
        return idx;
    }

    std::size_t lens[NDim]    = {};
    std::size_t strides[NDim] = {};
};

} // namespace device
...
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_TENSOR_VIEW_HPP
#include <migraphx/gpu/device/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
struct hip_tensor_view
{
using value_type = T;
using hip_index = typename hip_shape<N>::hip_index;
__device__ __host__ hip_tensor_view() = default;
__host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
__host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements(); }
MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }
template <class U>
MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const
{
return d[s.index(i)];
}
MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR value_type* end() const { return d + size(); }
private:
value_type* d = nullptr;
hip_shape<N> s{};
};
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
{
return {x, s};
}
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_view(tensor_view<T> x)
{
return {x};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
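// Usage sketch (illustrative; the pointer and shape below are assumptions):
//     float* ptr = /* device pointer */;
//     shape s{shape::float_type, {2, 3, 4}};
//     auto v = make_hip_view<3>(s, ptr); // hip_tensor_view<float, 3>
//     // v[i] goes through hip_shape<3>::index(i), so non-packed layouts are
//     // addressed correctly inside kernels.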
...@@ -8,14 +8,45 @@

#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP

#include <hip/hip_runtime.h>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tensor_view.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));
template <std::size_t N, class T>
__device__ __host__ T* as_pointer(vec<T, N>* x)
{
return reinterpret_cast<T*>(x);
}
template <std::size_t N, class T>
__device__ __host__ vec<T, N>* as_vec(T* x)
{
return reinterpret_cast<vec<T, N>*>(x);
}
template <std::size_t N, class T>
tensor_view<vec<T, N>> as_vec(tensor_view<T> x)
{
return {x.get_shape(), as_vec<N>(x.data())};
}
template <std::size_t N, class... Ts>
auto pack_vec(Ts... xs)
{
return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
}
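// Illustrative example (not from this commit): as_vec<4> reinterprets a packed
// buffer as 4-wide vectors, and pack_vec<4> bundles several such pointers so a
// kernel body can pull one vec per input at a given vector index n. The pointers
// below are assumptions for the example.
//     float* a = /* device pointer */;
//     float* b = /* device pointer */;
//     auto data = pack_vec<4>(a, b);
//     data([](auto x, auto y) { /* x and y are vec<float, 4> */ }, n);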
using gpu_half = __fp16;

namespace detail {

...@@ -25,6 +56,12 @@ struct device_type
    using type = T;
};
template <class T, std::size_t N>
struct device_type<vec<T, N>>
{
using type = vec<typename device_type<T>::type, N>;
};
template <>
struct device_type<half>
{
...@@ -38,7 +75,7 @@ struct host_type
};

template <>
struct host_type<gpu_half>
{
    using type = half;
};

...@@ -64,9 +101,9 @@ host_type<T>* host_cast(T* x)
}

template <class T>
device_type<T> device_cast(const T& x)
{
    return reinterpret_cast<const device_type<T>&>(x);
}

template <class T>
...@@ -75,6 +112,12 @@ device_type<T>* device_cast(T* x)
    return reinterpret_cast<device_type<T>*>(x);
}
template <class T>
tensor_view<device_type<T>> device_cast(tensor_view<T> x)
{
return {x.get_shape(), reinterpret_cast<device_type<T>*>(x.data())};
}
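// Example (illustrative): device_cast on a tensor_view<half> yields a
// tensor_view<gpu_half> over the same storage and shape, so kernels operate on
// the native __fp16 type:
//     tensor_view<half> hv = /* view of device memory */;
//     auto dv = device_cast(hv); // tensor_view<gpu_half>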
template <class T>
T to_hip_type(T x)
{
...