Commit 2c603654 authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into keep_std_shape

parents 36947a39 2a79a9ff
......@@ -142,10 +142,12 @@ jobs:
with:
python-version: 3.6
- name: Install pyflakes
run: pip install pyflakes==2.3.1
run: pip install pyflakes==2.3.1 mypy==0.931
- name: Run pyflakes
run: pyflakes examples/ tools/ src/ test/ doc/
run: |
pyflakes examples/ tools/ src/ test/ doc/
mypy tools/api.py
linux:
......
......@@ -200,6 +200,7 @@ rocm_enable_cppcheck(
RULE_FILE
${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules
SOURCES
examples/
src/
test/
INCLUDE
......
......@@ -24,7 +24,6 @@ int main(int argc, char** argv)
return 0;
}
char* parse_arg = getCmdOption(argv + 2, argv + argc, "--parse");
char* load_arg = getCmdOption(argv + 2, argv + argc, "--load");
char* save_arg = getCmdOption(argv + 2, argv + argc, "--save");
const char* input_file = argv[1];
......
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp>
......@@ -51,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m)
cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
......@@ -63,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str
return *this;
}
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
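A hedged sketch of what this enables (illustrative; parameter names are hypothetical): for a module with parameters x0 and x1, set_generic_types records tparams {"class Tx0", "class Tx1"}, parameters typed as Tx0/Tx1, and an auto return type, so create_function (extended below to emit tparams) would produce roughly
// template <class Tx0, class Tx1>
// __device__ auto f1(Tx0 x0, Tx1 x1) { /* generated body */ }
assuming the __device__ attribute is attached the way compile_pointwise does further down.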
struct cpp_generator_impl
{
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
......@@ -83,41 +104,54 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args)
{
auto v = op.to_value();
auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
return interpolate_string(attributes["point_op"].to<std::string>(),
[&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
auto v = op.to_value();
std::string code;
if(contains(impl->point_op_map, op.name()))
{
code = impl->point_op_map.at(op.name());
}
else
{
auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
code = attributes["point_op"].to<std::string>();
}
return interpolate_string(code, [&](auto start, auto last) -> std::string {
auto key = trim({start, last});
if(key.empty())
MIGRAPHX_THROW("Empty parameter");
std::string fselector = "function:";
if(starts_with(key, fselector))
{
auto fname = key.substr(fselector.size());
if(impl->fmap == nullptr)
return fname;
else
return impl->fmap(fname);
}
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
return args.at(i);
}
else if(v.contains(key))
{
return v[key].template to<std::string>();
}
else
{
return key;
}
});
}
std::string cpp_generator::str() const { return impl->fs.str(); }
......@@ -148,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
......
......@@ -34,6 +34,7 @@ struct cpp_generator
std::string return_type = "void";
std::string name = "";
std::vector<std::string> attributes = {};
std::vector<std::string> tparams = {};
function& set_body(const module& m, const generate_module_callback& g);
function& set_body(const std::string& s)
{
......@@ -52,6 +53,7 @@ struct cpp_generator
}
function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
};
cpp_generator();
......@@ -66,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
std::string str() const;
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto num_data = args.size();
if(num_data == 1)
return args[0];
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
if(data_i != args[0])
return info.add_broadcastable_binary_op("add", mean, data_i);
return data_i;
});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
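A brief worked note on the parser above: for three inputs a, b and c it builds add(add(a/3, b/3), c/3), so mean(a, b, c) = a/3 + b/3 + c/3 computed left to right; dividing each operand before accumulating keeps the intermediate sums near the magnitude of the inputs, which is the overflow concern noted in the loop comment. A single input is returned unchanged.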
......@@ -163,9 +163,9 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}}; }
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
......@@ -183,7 +183,7 @@ struct parse_resize : op_parser<parse_resize>
if(contains(info.attributes, "exclude_outside") and
info.attributes.at("exclude_outside").i() == 1)
{
MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!");
}
// input data shape info
......@@ -215,12 +215,14 @@ struct parse_resize : op_parser<parse_resize>
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size())
{
MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
}
// compute the scale
......@@ -239,12 +241,14 @@ struct parse_resize : op_parser<parse_resize>
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_RESIZE: dynamic input scale is not supported!");
"PARSE_" + opd.op_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
}
std::transform(in_lens.begin(),
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_upsample : op_parser<parse_upsample>
{
std::vector<op_desc> operators() const { return {{"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode != "nearest")
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
}
}
auto arg_scale = args[1]->eval();
check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
std::vector<float> vec_scale;
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
}
std::vector<std::size_t> out_lens(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
std::vector<float> idx_scale(in_lens.size());
std::transform(
out_lens.begin(),
out_lens.end(),
in_lens.begin(),
idx_scale.begin(),
[](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });
shape out_s{in_s.type(), out_lens};
std::vector<int> ind(out_s.elements());
// map out_idx to in_idx
shape_for_each(out_s, [&](auto idx) {
auto in_idx = idx;
std::transform(idx.begin(),
idx.end(),
idx_scale.begin(),
in_idx.begin(),
// nearest mode
[](auto index, auto scale) {
return static_cast<std::size_t>(std::round(index * scale));
});
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
});
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = info.add_literal(literal(ind_s, ind));
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -114,6 +114,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
if(cos.size() != 1)
......
......@@ -63,8 +63,16 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
return compile_pointwise((ctx), inputs, "&" + name, g.str());
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
}
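A hedged trace of how generate_point_op expands the first registration above (argument names are hypothetical):
// g.generate_point_op(op, {"c", "a", "b"}) with op.name() == "where" finds
// "${function:where}(${0}, ${1}, ${2})" in point_op_map; interpolate_string
// then resolves "function:where" through fmap (which prefixes "migraphx::")
// and the numeric keys through args.at(i), yielding:
//   migraphx::where(c, a, b)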
} // namespace gpu
......
......@@ -5,6 +5,9 @@
namespace migraphx {
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
// Work around HIP's broken abort in device code
#ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE
......@@ -14,19 +17,67 @@ namespace migraphx {
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif
namespace debug {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
template <size_t N>
struct print_buffer
{
char buffer[N + 1] = {0};
char* pos = buffer;
constexpr void append(char c)
{
if(c == 0)
return;
if(pos < buffer + N)
{
*pos = c;
pos++;
}
}
template <size_t M>
constexpr void append(const char (&array)[M])
{
for(int i = 0; i < M; i++)
append(array[i]);
}
};
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
const auto size = (sizeof(xs) + ...);
print_buffer<size> buffer;
swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer);
}
} // namespace debug
// noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function)
{
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
// printf is broken on hip with more than one argument, so use a simple print function instead
debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n");
// printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort();
}
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto... xs) { \
assert_fail(xs...); \
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
#else
#define MIGRAPHX_ASSERT(cond)
#endif
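A hedged sketch of a failing assertion on the new path (values are illustrative): MIGRAPHX_ASSERT(x > 0) expands to a lambda call passing "x > 0", __FILE__, MIGRAPHX_STRINGIZE(__LINE__) and __PRETTY_FUNCTION__ as strings, so assert_fail ends up calling roughly
// debug::print("kernel.hpp", ":", "42", ": ", "void f()",
//              ": assertion '", "x > 0", "' failed.\n");
print sums the sizeof of its arguments to size a print_buffer, appends each piece (skipping NUL terminators), and emits the whole message with a single printf("%s", ...), sidestepping the multi-argument printf issue on HIP.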
......
......@@ -137,12 +137,48 @@ constexpr void each_args(F)
{
}
template <class F, class T>
constexpr auto fold_impl(F&&, T&& x)
{
return static_cast<T&&>(x);
}
template <class F, class T, class U, class... Ts>
constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
}
template <class F>
constexpr auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}
template <class... Ts>
auto pack(Ts... xs)
constexpr auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
return p1([&](auto... xs) {
return p2([&](auto... ys) {
auto c = [&](auto x, auto y) -> int {
if(compare(x, y))
return 1;
else if(compare(y, x))
return -1;
else
return 0;
};
return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0);
});
});
}
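A hedged illustration of the new helpers (standalone, not part of the diff): pack_compare performs a lexicographic three-way comparison of two packs, and fold keeps the first non-zero element-wise result.
// pack_compare(less{}, pack(1, 5), pack(1, 7))
//   -> element results: c(1, 1) == 0, c(5, 7) == 1
//   -> fold([](auto x, auto y) { return x ? x : y; }) returns the first
//      non-zero value, so the overall result is 1
// find_vector_axis_c in vectorize.hpp uses exactly this form to order
// (stride, len) pairs.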
template <index_int N>
constexpr auto arg_c()
{
......@@ -187,7 +223,7 @@ constexpr auto transform_args(F f, Fs... fs)
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
([](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
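A hedged note on the macro change above (the unbalanced leading parenthesis is dropped): MIGRAPHX_LIFT wraps an overload set or function template in a generic lambda so it can be passed as an ordinary value, e.g.
// auto abs_fn = MIGRAPHX_LIFT(migraphx::abs); // generic lambda forwarding to migraphx::abs
// abs_fn(x);                                  // same as migraphx::abs(x), x illustrative
compile_pointwise relies on the same mechanism, wrapping the generated (now templated) pointwise function name in MIGRAPHX_LIFT so it can be passed into the kernel source.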
......@@ -13,6 +13,7 @@ struct integral_constant
using type = integral_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
static constexpr type to() { return {}; }
};
// NOLINTNEXTLINE
......
......@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
......@@ -19,19 +20,30 @@ constexpr T as_float(T x)
} // namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts> \
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts> \
auto __device__ name(type x, Ts... xs) MIGRAPHX_RETURNS(fname(x, xs...))
#define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::half x, Ts... xs) \
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH(abs, ::abs)
......@@ -99,21 +111,51 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
template <class T, class U>
constexpr auto& max(const T& a, const U& b)
constexpr auto where(bool cond, const T& a, const U& b)
{
return (a < b) ? b : a;
return cond ? a : b;
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
MIGRAPHX_DEVICE_MATH_VEC(asin)
MIGRAPHX_DEVICE_MATH_VEC(asinh)
MIGRAPHX_DEVICE_MATH_VEC(atan)
MIGRAPHX_DEVICE_MATH_VEC(atanh)
MIGRAPHX_DEVICE_MATH_VEC(ceil)
MIGRAPHX_DEVICE_MATH_VEC(cos)
MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin)
MIGRAPHX_DEVICE_MATH_VEC(sinh)
MIGRAPHX_DEVICE_MATH_VEC(sqrt)
MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto& min(const T& a, const U& b)
constexpr auto max(const T& a, const U& b)
{
return (a > b) ? b : a;
return where(a < b, b, a);
}
template <class T, class U>
constexpr T convert(U x)
constexpr auto min(const T& a, const U& b)
{
return x;
return where(a > b, b, a);
}
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) -> T { return x; });
}
} // namespace migraphx
......
......@@ -10,13 +10,38 @@
namespace migraphx {
template <class T>
struct implicit_conversion_op
{
T x;
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
template <class U>
constexpr operator U() const
{
return x;
}
};
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
return {x};
}
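A hedged note on the helper above (types are illustrative): implicit_conversion defers the destination type to the assignment site, so a vectorized result converts lane-wise while a scalar result converts normally.
// If f returns vec<float, 4> and the output element type is, say, vec<half, 4>,
//   out[multi_idx] = implicit_conversion(f(...));
// selects the operator vec<U, N>() overload and converts all four lanes with
// __builtin_convertvector; a plain float result goes through operator U().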
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = f(ps[multi_idx]...);
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
});
}
......@@ -24,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps)
{
auto t = transform_args(make_tensors(), rotate_last());
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
t(ps...)([&](auto... xs) {
auto idx = make_index();
pointwise_tensor(idx, f, xs...);
......
......@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
}
template <class T, class... Shapes>
constexpr index_int compute_preload_size(Shapes...)
constexpr index_int compute_preload_size_c(Shapes...)
{
index_int size = 0;
traverse_preload<T>(Shapes{}...)(
......@@ -37,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
return size;
}
template <class T, class... Shapes>
constexpr auto compute_preload_size(Shapes...)
{
return _c<compute_preload_size_c<T>(Shapes{}...)>;
}
template <class F, class T, class... Ts>
__device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{
......@@ -48,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
[&](auto x, auto offset, auto copy) {
if constexpr(copy)
{
auto v = vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
if constexpr(decltype(tensor_vec_size(x)){} == 0)
{
auto v = vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
}
else
{
auto b = as_vec(tensor_vec_size(x), buffer + offset);
idx.local_stride(x.get_shape().element_space(),
[&](auto i) { b[i] = x.data()[i]; });
return x.with(b);
}
}
else
{
......@@ -78,7 +94,7 @@ template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs)
{
using type = typename remove_vec<T>::type;
constexpr auto size = compute_preload_size<type>(xs.get_shape()...);
constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
const index_int max_size = 512 * sizeof(type);
return [=](auto f) {
if constexpr(size > 0 and size < max_size)
......
......@@ -3,6 +3,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx {
......@@ -24,6 +25,38 @@ constexpr auto vec_size()
return decltype(vec_size(T{})){};
}
template <class... Ts>
constexpr auto is_any_vec()
{
if constexpr(sizeof...(Ts) == 0)
return false_type{};
else
return bool_constant<((vec_size<Ts>() + ...) > 0)>{};
}
template <class T, class I>
constexpr auto vec_at(T x, I i)
{
if constexpr(vec_size<T>() == 0)
return x;
else
{
MIGRAPHX_ASSERT(i < vec_size<T>());
return x[i];
}
}
template <class... Ts>
constexpr auto common_vec_size()
{
return fold([](auto x, auto y) {
if constexpr(x > y)
return x;
else
return y;
})(vec_size<Ts>()...);
}
template <index_int N, class T>
__device__ __host__ auto as_vec(T* x)
{
......@@ -33,5 +66,25 @@ __device__ __host__ auto as_vec(T* x)
return reinterpret_cast<vec<T, N>*>(x);
}
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
{
return [=](auto f) {
if constexpr(is_any_vec<Ts...>())
{
using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>();
vec<type, size> result = {0};
for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...);
return result;
}
else
{
return f(xs...);
}
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
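A hedged illustration of vec_transform (standalone, values illustrative): when any argument is a vector it applies the callable lane by lane, broadcasting scalars through vec_at; with only scalars it simply forwards the call.
// vec<float, 2> v = {1.0f, 2.0f};
// auto r = vec_transform(v, 3.0f)([](auto x, auto y) { return x + y; });
// // r is vec<float, 2>{4.0f, 5.0f}; vec_at(v, i) picks the lane and
// // vec_at(3.0f, i) returns the scalar unchanged.
// MIGRAPHX_DEVICE_MATH_VEC in math.hpp builds on this to lift the scalar math
// wrappers to vector arguments.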
......@@ -7,40 +7,70 @@
namespace migraphx {
template <class T>
constexpr auto tensor_vec_size(T)
constexpr auto tensor_vec_size()
{
return vec_size<typename T::type>();
}
template <index_int N, class Shape>
constexpr auto as_vec_shape(Shape s)
template <class T>
constexpr auto tensor_vec_size(T)
{
auto lens = transform(s.lens, s.strides, [](auto len, auto stride) {
if(stride == 1)
return len / N;
else
return len;
});
auto strides = transform(s.strides, [](auto stride) {
if(stride == 1)
return stride;
return stride / N;
return tensor_vec_size<T>();
}
template <index_int N, class Shape, class Axis>
constexpr auto shape_step(Shape s, Axis)
{
static_assert(N > 0, "Vector size must be non-zero");
return sequence(s.lens.size(), [&](auto... is) {
auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis::to();
MIGRAPHX_ASSERT(i != 0);
MIGRAPHX_ASSERT(j != axis or i % N == 0);
if(j == axis)
return i / N;
else
return i;
});
auto strides = transform(s.strides, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis::to();
// If the stride of the axis is zero then we don't need to adjust the other strides
if(Shape{}.strides[axis] == 0)
return i;
MIGRAPHX_ASSERT(j == axis or i % N == 0);
if(j == axis)
return i;
else
return i / N;
});
MIGRAPHX_ASSERT(make_shape(lens, strides).elements() * N == s.elements());
MIGRAPHX_ASSERT(strides[Axis{}] == 0 or
make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
});
MIGRAPHX_ASSERT(make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
}
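A hedged worked example for shape_step (values illustrative): stepping a packed shape with lens {2, 8}, strides {8, 1} by N = 4 along axis 1 gives lens {2, 2} and strides {2, 1} — the axis length is divided by N and the remaining strides are rescaled — so elements() and element_space() both shrink by exactly 4x, which is what the asserts above verify. For a broadcast axis (stride 0) the strides are left untouched and only the axis length is divided.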
template <index_int N, class T>
__device__ __host__ auto as_vec(T x)
// Bools can not be used as a vector type so convert it to int8
template <class T>
__device__ __host__ T* remove_bool(T* x)
{
return x;
}
inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); }
template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis)
{
if constexpr(N == 0)
return x;
else
return make_tensor_view(as_vec<N>(x.data()), as_vec_shape<N>(x.get_shape()));
return make_tensor_view(as_vec<N>(remove_bool(x.data())),
shape_step<N>(x.get_shape(), axis));
}
template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis)
constexpr auto tensor_step(T x, Axis axis)
{
if constexpr(N == 0)
{
......@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis)
else
{
constexpr auto s = decltype(x.get_shape()){};
MIGRAPHX_ASSERT(s.strides[Axis{}] == 0);
return sequence(x.get_shape().lens.size(), [&](auto... is) {
auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis{};
if(j == axis)
return i / N;
else
return i;
});
return make_tensor_view(x.data(), make_shape(lens, s.strides));
});
MIGRAPHX_ASSERT(s.strides[axis] == 0);
return make_tensor_view(x.data(), shape_step<N>(s, axis));
}
}
......@@ -69,45 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
return as_vec<ic>(x);
}
template <class... Shapes>
constexpr index_int find_vector_axis(Shapes... ss)
template <class Shape>
constexpr index_int find_vector_axis_c(Shape s)
{
// Find the fastest axis that is not broadcasted
index_int axis = 0;
bool b = false;
for(index_int i = 1; i < s.lens.size(); i++)
{
if(s.strides[i] == 0)
continue;
if(s.strides[axis] == 0 or
pack_compare(less{}, pack(s.strides[i], s.lens[i]), pack(s.strides[axis], s.lens[axis])))
axis = i;
}
return axis;
}
template <class... Shapes>
constexpr index_int find_vector_axis_c(Shapes... ss)
{
const bool all_broadcasted = (ss.broadcasted() and ...);
index_int axis = 0;
bool b = false;
by([&](auto s) {
if(b)
return;
auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
// Skip broadcasted shapes when at least one shape is not broadcasted
if(not all_broadcasted and s.broadcasted())
return;
axis = it - s.strides.begin();
b = true;
axis = find_vector_axis_c(s);
if(s.strides[axis] == 1)
b = true;
})(ss...);
if(not b)
return -1;
return axis;
}
template <class... Shapes>
constexpr auto find_vector_axis(Shapes...)
{
return _c<find_vector_axis_c(Shapes{}...)>;
}
template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis axis, Shapes... ss)
constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
{
return (((ss.lens[axis] % N) == 0 and ss.strides[axis] == 1) and ...);
return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader
((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
...);
}
template <index_int N, class Shape>
constexpr bool is_vectorizable(Shape s)
template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis, Shapes...)
{
auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
return false;
auto axis = it - s.strides.begin();
return (s.lens[axis] % N) == 0 and s.strides[axis] == 1;
return _c<is_vectorizable_c<N>(Axis::to(), Shapes{}...)>;
}
template <class P>
constexpr auto find_vectorize_size(P pred)
{
if constexpr(pred(_c<4>))
if constexpr(decltype(pred(_c<4>)){})
return _c<4>;
else if constexpr(pred(_c<2>))
else if constexpr(decltype(pred(_c<2>)){})
return _c<2>;
else
return _c<0>;
......@@ -116,11 +163,12 @@ constexpr auto find_vectorize_size(P pred)
template <class T>
__host__ __device__ auto vectorize(T x)
{
if constexpr(vec_size<T>() == 0)
if constexpr(tensor_vec_size<T>() == 0)
{
constexpr auto axis = find_vector_axis(x.get_shape());
constexpr auto n =
find_vectorize_size([&](auto i) { return _c<is_vectorizable<i>(x.get_shape())>; });
return as_vec<n>(x);
find_vectorize_size([&](auto i) { return is_vectorizable<i>(axis, x.get_shape()); });
return as_vec<n>(x, axis);
}
else
{
......@@ -128,34 +176,46 @@ __host__ __device__ auto vectorize(T x)
}
}
template <class F, class... Ts>
inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{
// TODO: Just check that there is a single axis with stride 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = decltype(find_vector_axis(xs.get_shape()...)){};
constexpr auto n = find_vectorize_size(
[&](auto i) { return is_vectorizable<i>(axis, xs.get_shape()...); });
by(
[&](auto x) {
constexpr auto s = decltype(x.get_shape()){};
if constexpr(axis < s.strides.size())
{
MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
MIGRAPHX_ASSERT(s.lens[axis] > 0);
MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x, axis);
}
else
{
return x;
}
},
f)(xs...);
}
else
{
f(xs...);
}
}
inline __device__ __host__ auto auto_vectorize()
{
return [](auto... xs) {
return [=](auto f) {
// TODO: Just check that there is a single axis with stride 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = find_vector_axis(xs.get_shape()...);
constexpr auto n = find_vectorize_size(
[&](auto i) { return _c<is_vectorizable<i>(axis, xs.get_shape()...)>; });
by(
[&](auto x) {
constexpr auto s = x.get_shape();
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x);
},
f)(xs...);
}
else
{
f(xs...);
}
};
};
return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
}
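A hedged end-to-end sketch (shapes are illustrative): for an output with lens {2, 8}, strides {8, 1} and an input broadcast along the fast axis with strides {1, 0}, find_vector_axis settles on axis 1 from the non-broadcast shape and find_vectorize_size picks n = 4; the packed output is rewritten by as_vec<4> into a 2x2 view of vec elements, while the broadcast input goes through tensor_step<4>, keeping scalar elements and stride 0 but dividing the axis-1 length by 4, so both views iterate the same vectorized index space inside pointwise_tensor.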
} // namespace migraphx
......
......@@ -2762,6 +2762,80 @@ def maxpool_same_upper_test():
return ([node], [x], [y])
@onnx_test
def mean_broadcast_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4])
data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT,
[1, 2, 3, 4])
data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [4])
data_3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1])
data_4 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2, 3, 1])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT,
[1, 2, 3, 4])
node = onnx.helper.make_node("Mean",
inputs=["0", "1", "2", "3", "4"],
outputs=["mean"])
return ([node], [data_0, data_1, data_2, data_3, data_4], [mean])
@onnx_test
def mean_fp16_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [1, 2, 3])
data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 2, 3])
data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 2, 3])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT16,
[1, 2, 3])
node = onnx.helper.make_node("Mean",
inputs=["0", "1", "2"],
outputs=["mean"])
return ([node], [data_0, data_1, data_2], [mean])
@onnx_test
def mean_invalid_broadcast_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3])
data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2, 3])
data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 2, 4])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [1, 2, 3])
node = onnx.helper.make_node("Mean",
inputs=["0", "1", "2"],
outputs=["mean"])
return ([node], [data_0, data_1, data_2], [mean])
@onnx_test
def mean_single_input_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [1, 2, 3])
node = onnx.helper.make_node("Mean", inputs=["0"], outputs=["mean"])
return ([node], [data_0], [mean])
@onnx_test
def mean_test():
data = [
helper.make_tensor_value_info(str(i), TensorProto.DOUBLE, [2, 2, 2])
for i in range(10)
]
data_names = [str(i) for i in range(10)]
mean = helper.make_tensor_value_info('mean', TensorProto.DOUBLE, [2, 2, 2])
node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"])
return ([node], data, [mean])
@onnx_test
def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......@@ -5000,6 +5074,25 @@ def unknown_aten_test():
return ([node], [x, y], [a])
@onnx_test
def upsample_linear_test():
scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)
scales_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(
np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Upsample',
inputs=['X', '', 'scales'],
outputs=['Y'],
mode='linear')
return ([node], [X], [Y], [scales_tensor])
@onnx_test
def upsample_test():
scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
......
mean_broadcast_test.onnx (new binary file): ONNX protobuf for the mean_broadcast_test model generated above — a Mean node taking inputs 0 through 4 and producing output mean; binary content not shown.