Commit 3ce7ad8b authored by Shucai Xiao

merge from develop branch and resolve merge conflicts

parents d0202590 b304d97d
@@ -9,7 +9,7 @@ namespace gpu {
 shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).standard();
+    check_shapes{inputs, *this}.has(2);
     return op.normalize_compute_shape({inputs.at(0)});
 }
......
@@ -9,7 +9,7 @@ namespace gpu {
 shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
 {
-    check_shapes{inputs, *this}.has(2).standard();
+    check_shapes{inputs, *this}.has(2);
     return op.normalize_compute_shape({inputs.at(0)});
 }
......
@@ -108,12 +108,13 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
     srcs.push_back(src_file{fs::path{"main.cpp"},
                             std::make_pair(content.data(), content.data() + content.size())});
     auto args_hpp =
-        generate_args_hpp(options.reduced_inputs.empty() ? options.inputs : options.reduced_inputs);
+        generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
     srcs.push_back(src_file{fs::path{"args.hpp"},
                             std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
     options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
     options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
     options.params += " " + join_strings(compiler_warnings(), " ");
+    options.params += " -ftemplate-backtrace-limit=0";
     options.params += " -Werror";
     auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
     if(cos.size() != 1)
......
@@ -12,6 +12,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
+
 struct precompile_op
 {
     operation op = op::identity{};
@@ -70,6 +72,14 @@ struct compiled_result
     instruction_ref ins;
 };

+template <class F>
+void par_compile(std::size_t n, F f)
+{
+    if(n == 0)
+        return;
+    par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
+}
+
 void compile_ops::apply(module& m) const
 {
     auto compilers = make_compilers(pointwise_compiler{});
@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
         compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
     }
     std::vector<compiled_result> results(compiles.size());
-    par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
+    par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
     for(const auto& cr : results)
     {
         m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
......
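Note on the compile-time parallelism added above: par_for and value_of are the library's own helpers, and the change appears to split the n pending kernel compilations into chunks sized by the MIGRAPHX_GPU_COMPILE_PARALLEL environment variable (falling back to n, i.e. one task per worker, when unset). A minimal standalone sketch of that divide-into-chunks idea, using std::thread instead of the library's par_for; everything in it is illustrative, not the migraphx implementation:

// Sketch: run f(0..n-1) using roughly n/grainsize threads,
// which is what a call like par_for(n, grainsize, f) amounts to.
#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

template <class F>
void par_for_sketch(std::size_t n, std::size_t grainsize, F f)
{
    if(n == 0)
        return;
    std::size_t nthreads = std::max<std::size_t>(1, n / std::max<std::size_t>(1, grainsize));
    std::vector<std::thread> threads;
    for(std::size_t t = 0; t < nthreads; t++)
    {
        threads.emplace_back([=] {
            // Each thread handles a strided subset of the indices
            for(std::size_t i = t; i < n; i += nthreads)
                f(i);
        });
    }
    for(auto& th : threads)
        th.join();
}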
@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
 #include <migraphx/kernels/pointwise.hpp>
 #include <args.hpp>

-using namespace migraphx;
+namespace migraphx {

 ${preamble}
@@ -32,6 +32,8 @@ __global__ void kernel(${params})
 }

+} // namespace migraphx
+
 int main() {}
 )__migraphx__";
@@ -46,7 +48,7 @@ operation compile_pointwise(context&,
     options.local  = 1024;
     options.inputs = inputs;
     options.output = inputs.back();
-    options.reduced_inputs = reduce_dims(inputs);
+    options.virtual_inputs = reduce_dims(inputs);
     options.params = "-Wno-float-equal";
     auto src = interpolate_string(pointwise_kernel,
                                   {{"params", enum_params(inputs.size(), "void * private_p")},
@@ -60,8 +62,20 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
 {
     run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
     cpp_generator g;
-    auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
-    return compile_pointwise((ctx), inputs, "&" + name, g.str());
+    g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
+    g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
+    g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
+    g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
+    g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
+    g.add_point_op("less", "migraphx::abs(${0} < ${1})");
+    g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
+    g.add_point_op("not", "migraphx::abs(not ${0})");
+    // Add explicit conversions
+    g.fresult(
+        [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
+    auto name =
+        g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
+    return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
 }
 } // namespace gpu
......
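The point-op templates registered above (where, prelu, sign, equal, ...) all reduce to a ternary select; the ${...} placeholders are presumably filled in by the cpp_generator. A small host-side sketch spelling out what the generated expressions compute for scalars (the where definition mirrors the one added to math.hpp later in this diff; the rest is illustrative):

#include <cassert>

// Scalar stand-in for migraphx::where
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
{
    return cond ? a : b;
}

// What the registered templates expand to for scalar inputs
constexpr float prelu(float x, float slope) { return where(x < 0, x * slope, x); }
constexpr int sign(float x) { return where(x > 0.0f, 1, where(x < 0.0f, -1, 0)); }

int main()
{
    static_assert(prelu(-2.0f, 0.5f) == -1.0f, "negative side is scaled by the slope");
    static_assert(prelu(3.0f, 0.5f) == 3.0f, "positive side passes through");
    static_assert(sign(-7.0f) == -1 and sign(0.0f) == 0 and sign(2.0f) == 1, "three-way sign");
    return 0;
}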
@@ -50,7 +50,7 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
     options.inputs      = io_shapes;
     options.output      = out_s;
     options.kernel_name = "roialign_kernel";
-    options.reduced_inputs = io_shapes;
+    options.virtual_inputs = io_shapes;
     // sampling_ratio
     assert(val.contains("sampling_ratio"));
......
@@ -75,8 +75,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
 inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
 {
     index_int groups  = (n + local - 1) / local;
-    index_int nglobal = std::min<index_int>(256, groups) * local;
+    // max possible number of blocks is set to 1B (1,073,741,824)
+    index_int nglobal = std::min<index_int>(1073741824, groups) * local;
     return [=](auto f) {
         launch(stream, nglobal, local)([=](auto idx) __device__ {
......
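A worked example of the new grid sizing, following the arithmetic visible above (index_int's real definition lives in the library; a 32-bit unsigned type is assumed here purely for the sketch):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main()
{
    using index_int = std::uint32_t; // assumption for the sketch only
    // Same arithmetic as gs_launch above, for n = 1,000,000 elements and local = 1024
    index_int n       = 1000000;
    index_int local   = 1024;
    index_int groups  = (n + local - 1) / local;                          // 977
    index_int nglobal = std::min<index_int>(1073741824, groups) * local;  // 1,000,448
    assert(groups == 977);
    assert(nglobal == 1000448);
    // With the old cap of 256 groups, nglobal would have been 256 * 1024 = 262,144,
    // so each work-item would have looped over roughly 4 elements instead of 1.
    return 0;
}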
@@ -20,34 +20,58 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        const index_int max_block_size = 256;
+        const index_int max_block_size = 128;
         const index_int block_size = compute_block_size(batch_item_num, max_block_size);
-        gs_launch(stream,
-                  batch_shape.elements() * block_size,
-                  block_size)([=](auto i, auto idx) __device__ {
-            auto data_idx = batch.multi(i / block_size);
-            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            type init = lowest();
-
-            auto batch_max = block_reduce<max_block_size>(
-                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
-                    data_idx[axis] = j;
-                    return input[data_idx];
-                });
-
-            auto batch_sum =
-                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
-                    data_idx[axis] = j;
-                    auto val = input[data_idx] - batch_max;
-                    return ::exp(to_hip_type(val));
-                });
-
-            idx.local_stride(batch_item_num, [&](auto j) __device__ {
-                data_idx[axis] = j;
-                auto val = input[data_idx] - batch_max;
-                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
-            });
-        });
+        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+        type init = lowest();
+
+        if(axis == batch_lens.size() - 1)
+        {
+            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+                [=](auto i, auto idx) __device__ {
+                    auto start_loc = i / block_size * batch_item_num;
+                    auto batch_max = block_reduce<max_block_size>(
+                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                            return input[start_loc + j];
+                        });
+
+                    auto batch_sum = block_reduce<max_block_size>(
+                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                            auto val = input[start_loc + j] - batch_max;
+                            return ::exp(to_hip_type(val));
+                        });
+
+                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
+                        auto val = input[start_loc + j] - batch_max;
+                        output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
+                    });
+                });
+        }
+        else
+        {
+            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+                [=](auto i, auto idx) __device__ {
+                    auto data_idx = batch.multi(i / block_size);
+                    auto batch_max = block_reduce<max_block_size>(
+                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                            data_idx[axis] = j;
+                            return input[data_idx];
+                        });
+
+                    auto batch_sum = block_reduce<max_block_size>(
+                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                            data_idx[axis] = j;
+                            auto val = input[data_idx] - batch_max;
+                            return ::exp(to_hip_type(val));
+                        });
+
+                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
+                        data_idx[axis] = j;
+                        auto val = input[data_idx] - batch_max;
+                        output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
+                    });
+                });
+        }
     });
 }
......
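The kernel above is the usual max-subtracted softmax over one axis; the new fast path simply indexes contiguously when the reduction axis is the innermost one. A host-side reference for a single batch row, mirroring the block_reduce(max) / block_reduce(sum) / normalize steps (illustrative only, not the device code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax over one contiguous row
std::vector<float> softmax_row(const std::vector<float>& x)
{
    float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    for(float v : x)
        sum += std::exp(v - m);
    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); i++)
        out[i] = std::exp(x[i] - m) / sum;
    return out;
}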
@@ -62,6 +62,8 @@ struct fusion
         keep_alive(std::move(t));
     }

+    bool empty() const { return fp == nullptr; }
+
     op_t operator[](std::size_t i) const
     {
         assert(fp);
@@ -125,12 +127,11 @@ struct fusion
         return shape{shape::int8_type, {ws_size}};
     }

-    void compile(context& ctx)
+    bool compile(context& ctx)
     {
         assert(fp);
-        auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
-        if(status != miopenStatusSuccess)
-            MIGRAPHX_THROW("Compiling fusion plan failed");
+        return miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get()) ==
+               miopenStatusSuccess;
     }

     argument execute(context& ctx,
@@ -169,7 +170,7 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
 MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
 {
-    const auto device_name = split_string(get_device_name(), ':').front();
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
     if(not contains(get_supported_archs(), device_name))
         return false;
     if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
@@ -561,6 +562,122 @@ struct find_mul_add_relu
 }
 };
struct miopen_fusion
{
struct fuse_op_data
{
operation op;
float alpha = 1;
float beta = 0;
};
struct fuse_op : fuse_op_data, reflect_equality<fuse_op>, reflect_stream<fuse_op>
{
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"), f(self.alpha, "alpha"), f(self.beta, "beta"));
}
};
std::vector<fuse_op> ops = {};
fusion f = {};
std::function<void(context&, const fusion&, const std::vector<argument>&)> execute;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.ops, "ops"));
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
value compile(context& ctx, const shape&, std::vector<shape> inputs)
{
// Compensate for allocation
inputs.pop_back();
std::size_t i = 0;
f = fusion(inputs[i]);
i++;
std::vector<std::function<void(const fused_operator_args&, const std::vector<argument>&)>>
invokers;
for(auto&& fop : ops)
{
if(i > inputs.size())
{
f = {};
return {};
}
if(fop.op.name() == "convolution")
{
auto* mop = f.create_conv(any_cast<op::convolution>(fop.op), inputs[i]);
invokers.push_back(
[=](const fused_operator_args& fargs, const std::vector<argument>& args) {
miopenSetOpArgsConvForward(
fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
});
i++;
}
else if(fop.op.name() == "add")
{
auto* mop = f.create_bias(inputs[i]);
invokers.push_back(
[=](const fused_operator_args& fargs, const std::vector<argument>& args) {
miopenSetOpArgsBiasForward(
fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
});
i++;
}
else if(fop.op.name() == "relu")
{
auto* mop = f.create_relu();
invokers.push_back([=](const fused_operator_args& fargs,
const std::vector<argument>&) {
miopenSetOpArgsActivForward(fargs.get(), mop, &fop.alpha, &fop.beta, 0, 0, 0);
});
}
else
{
f = {};
return {};
}
}
if(not f.compile(ctx))
{
f = {};
return {};
}
execute = [invokers](context& c, const fusion& ff, const std::vector<argument>& args) {
auto fargs = make_fused_args();
for(auto&& invoker : invokers)
invoker(fargs, args);
ff.execute(c, fargs, args.front(), args.back());
};
return {{"workspace", f.get_workspace(ctx).bytes()}};
}
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
{
if(not f.empty())
return;
auto v = compile(ctx, output_shape, inputs);
if(not v.is_object())
MIGRAPHX_THROW("Failed to compile fusion plan");
}
std::string name() const { return "gpu::miopen_fusion"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
if(ops.empty())
return {};
// TODO: Check number of arguments
return ops.front().op.compute_shape({inputs[0], inputs[1]});
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
execute(ctx, f, args);
return args.back();
}
};
 struct miopen_conv_bias
 {
     op::convolution op;
@@ -596,7 +713,8 @@ struct miopen_conv_bias
         f    = fusion(inputs[0]);
         conv = f.create_conv(op, inputs[1]);
         bias = f.create_bias(inputs[3]);
-        f.compile(ctx);
+        if(not f.compile(ctx))
+            MIGRAPHX_THROW("Failed to compile fusion plan");
     }

     shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
@@ -683,6 +801,25 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
     p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }
inline auto precompile_name(std::string s) // NOLINT
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "gpu::precompile_op")
return false;
auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
return (op.name() == s);
});
}
template <class... Ms>
auto conv_bias_pointwise(Ms... ms)
{
return precompile_name("pointwise")(
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")),
ms...);
}
 struct find_conv_bias
 {
     context* ctx = nullptr;
@@ -709,6 +846,46 @@ struct find_conv_bias_relu
     }
 };
struct find_conv_pointwise
{
context* ctx = nullptr;
auto matcher() const
{
return precompile_name("pointwise")(
match::nargs(3),
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")));
}
void apply(module& m, match::matcher_result r) const
{
auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"];
auto ins = r.result;
auto input_ins = conv_ins->inputs().at(0);
auto weights_ins = conv_ins->inputs().at(1);
auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
auto alloc_ins = ins->inputs().back();
module_ref pm = ins->module_inputs().front();
miopen_fusion op{};
op.ops.push_back({{conv_op}});
for(auto&& i : *pm)
{
if(i.name()[0] == '@')
continue;
auto inputs = to_shapes(i.inputs());
op.ops.push_back({{i.get_operator()}});
}
std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(inputs));
if(not v.is_object())
return;
m.replace_instruction(ins, op, inputs);
}
};
 struct find_gemm_add
 {
     auto matcher() const
@@ -778,6 +955,7 @@ void fuse_ops::apply(module& p) const
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
                         find_layernorm{},
+                        find_conv_pointwise{ctx},
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
                         find_add_gelu{},
......
@@ -16,7 +16,7 @@ struct hip_compile_options
     shape output;
     std::string kernel_name = "kernel";
     std::string params      = "";
-    std::vector<shape> reduced_inputs = {};
+    std::vector<shape> virtual_inputs = {};
 };

 operation compile_hip_code_object(const std::string& content, hip_compile_options options);
......
@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
     size_t batch_item_num = batch_lens[axis];
     batch_lens[axis]      = 1;
     migraphx::shape batch_shape{arg_shape.type(), batch_lens};
+    migraphx::shape std_arg_shape{arg_shape.type(), arg_shape.lens()};

-    hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
+    hip_visit_all(arg, std_arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
         auto* output = device_cast(result.get<int64_t>().data());
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
         // use one block for items in one batch.
......
@@ -5,6 +5,9 @@
 namespace migraphx {

+#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
+#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
+
 // Workaround hip's broken abort on device code
 #ifdef __HIP_DEVICE_COMPILE__
 // NOLINTNEXTLINE
@@ -14,19 +17,67 @@ namespace migraphx {
 #define MIGRAPHX_HIP_NORETURN [[noreturn]]
 #endif
namespace debug {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
template <size_t N>
struct print_buffer
{
char buffer[N + 1] = {0};
char* pos = buffer;
constexpr void append(char c)
{
if(c == 0)
return;
if(pos < buffer + N)
{
*pos = c;
pos++;
}
}
template <size_t M>
constexpr void append(const char (&array)[M])
{
for(int i = 0; i < M; i++)
append(array[i]);
}
};
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
const auto size = (sizeof(xs) + ...);
print_buffer<size> buffer;
swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer);
}
} // namespace debug
 // noreturn cannot be used on this function because abort in hip is broken
+template <class T1, class T2, class T3, class T4>
 MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
-assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
+assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function)
 {
-    printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
+    // printf is broken on hip with more than one argument, so use a simple print function instead
+    debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n");
+    // printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
     abort();
 }

 #ifdef MIGRAPHX_DEBUG
 #define MIGRAPHX_ASSERT(cond) \
-    ((cond) ? void(0) : [](auto... xs) { \
-        assert_fail(xs...); \
-    }(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
+    ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
+        assert_fail(private_migraphx_xs...); \
+    }(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
 #else
 #define MIGRAPHX_ASSERT(cond)
 #endif
......
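The debug::print helper above sizes a stack buffer at compile time from (sizeof(xs) + ...) and concatenates the string arguments so only a single printf("%s", ...) is needed on device. A host-only sketch of the same buffer idea; all names here are illustrative, not the migraphx ones:

#include <cstddef>
#include <cstdio>

template <std::size_t N>
struct append_buffer
{
    char data[N + 1] = {0};
    char* pos        = data;
    // Copy a string literal, skipping its trailing '\0', while there is room
    template <std::size_t M>
    void append(const char (&s)[M])
    {
        for(std::size_t i = 0; i < M and s[i] != 0; i++)
            if(pos < data + N)
                *pos++ = s[i];
    }
};

template <class... Ts>
void print_joined(const Ts&... xs)
{
    // sizeof of each char-array argument gives an upper bound on the total length
    constexpr auto size = (sizeof(xs) + ...);
    append_buffer<size> buf;
    (buf.append(xs), ...); // fold over the arguments, like swallow{} above
    std::printf("%s\n", buf.data);
}

int main() { print_joined("file.cpp", ":", "42", ": ", "assertion failed"); }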
@@ -16,6 +16,19 @@ struct swallow
 template <index_int>
 using ignore = swallow;
template <class... Fs>
struct overloaded : Fs...
{
using Fs::operator()...;
overloaded(Fs... fs) : Fs(fs)... {}
};
template <class... Fs>
overloaded<Fs...> overload(Fs... fs)
{
return {fs...};
}
 namespace detail {
 template <class R>
@@ -124,12 +137,48 @@ constexpr void each_args(F)
 {
 }
template <class F, class T>
constexpr auto fold_impl(F&&, T&& x)
{
return static_cast<T&&>(x);
}
template <class F, class T, class U, class... Ts>
constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
}
template <class F>
constexpr auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}
 template <class... Ts>
-auto pack(Ts... xs)
+constexpr auto pack(Ts... xs)
 {
     return [=](auto f) { return f(xs...); };
 }
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
return p1([&](auto... xs) {
return p2([&](auto... ys) {
auto c = [&](auto x, auto y) -> int {
if(compare(x, y))
return 1;
else if(compare(y, x))
return -1;
else
return 0;
};
return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0);
});
});
}
 template <index_int N>
 constexpr auto arg_c()
 {
@@ -168,9 +217,13 @@ constexpr auto transform_args(F f, Fs... fs)
     return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
 }

+// NOLINTNEXTLINE
+#define MIGRAPHX_RETURNS(...) \
+    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+
 // NOLINTNEXTLINE
 #define MIGRAPHX_LIFT(...) \
-    ([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
+    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))

 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
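The functional.hpp additions are easiest to read through a small compile-time usage sketch. fold builds a left fold, pack_compare uses it to get a lexicographic comparison of packs (the first nonzero element-wise comparison wins), and MIGRAPHX_LIFT wraps a function name in a generic lambda so an overload set can be passed around as a value. The sketch below re-declares a standalone analogue of fold so it is checkable on the host; only the names mirror the diff:

// Standalone analogue of the fold helper above (host C++17)
template <class F, class T>
constexpr auto fold_impl(F&&, T&& x)
{
    return static_cast<T&&>(x);
}
template <class F, class T, class U, class... Ts>
constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
    return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
}
template <class F>
constexpr auto fold(F f)
{
    return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}

// Left fold: ((1 + 2) + 3) + 4
static_assert(fold([](auto x, auto y) { return x + y; })(1, 2, 3, 4) == 10);

// The MIGRAPHX_LIFT trick, written out by hand for one function
constexpr auto twice(int x) { return 2 * x; }
constexpr auto lifted = [](auto&&... xs) { return twice(static_cast<decltype(xs)>(xs)...); };
static_assert(lifted(21) == 42);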
@@ -13,6 +13,7 @@ struct integral_constant
     using type = integral_constant;
     constexpr operator value_type() const noexcept { return value; }
     constexpr value_type operator()() const noexcept { return value; }
+    static constexpr type to() { return {}; }
 };

 // NOLINTNEXTLINE
......
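The to() member added above lets code rebuild an integral_constant from its type alone, which is how Axis::to() is used inside the shape_step lambdas later in this diff: the value is recovered without capturing the original object. A standalone sketch of that pattern (names are illustrative):

template <class T, T V>
struct integral_constant
{
    static constexpr T value = V;
    using type = integral_constant;
    static constexpr type to() { return {}; }
};

template <class Axis>
constexpr int axis_value(Axis)
{
    // No capture needed: the constant is recovered from the type
    constexpr auto axis = Axis::to();
    return decltype(axis)::value;
}

static_assert(axis_value(integral_constant<int, 2>{}) == 2);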
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
template <class T>
constexpr T as_float(T x)
{
return x;
}
} // namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH(abs, ::abs)
MIGRAPHX_DEVICE_MATH(acos, ::acos)
MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH(asin, ::asin)
MIGRAPHX_DEVICE_MATH(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH(atan, ::atan)
MIGRAPHX_DEVICE_MATH(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH(cos, ::cos)
MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin)
MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf)
MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf)
MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf)
MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf)
MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf)
MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf)
MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos)
MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
{
return cond ? a : b;
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
MIGRAPHX_DEVICE_MATH_VEC(asin)
MIGRAPHX_DEVICE_MATH_VEC(asinh)
MIGRAPHX_DEVICE_MATH_VEC(atan)
MIGRAPHX_DEVICE_MATH_VEC(atanh)
MIGRAPHX_DEVICE_MATH_VEC(ceil)
MIGRAPHX_DEVICE_MATH_VEC(cos)
MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin)
MIGRAPHX_DEVICE_MATH_VEC(sinh)
MIGRAPHX_DEVICE_MATH_VEC(sqrt)
MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto max(const T& a, const U& b)
{
return where(a < b, b, a);
}
template <class T, class U>
constexpr auto min(const T& a, const U& b)
{
return where(a > b, b, a);
}
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) -> T { return x; });
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
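For reference, here is what one instantiation of the macros in this new header expands to; this is the macro text written out by hand from the definitions above (MIGRAPHX_REQUIRES and is_any_vec come from the headers included at the top of the file):

// MIGRAPHX_DEVICE_MATH(cos, ::cos) expands to roughly:
//
//   template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())>
//   auto __device__ cos(Ts... xs) -> decltype(::cos(xs...)) { return ::cos(xs...); }
//
// MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos) adds a half overload that promotes to float:
//
//   template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())>
//   auto __device__ cos(migraphx::half x, Ts... xs)
//       -> decltype(::cos(math::as_float(x), math::as_float(xs)...))
//   { return ::cos(math::as_float(x), math::as_float(xs)...); }
//
// MIGRAPHX_DEVICE_MATH_VEC(cos) then maps the scalar overloads across vector arguments:
//
//   template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())>
//   auto __device__ cos(Ts... xs)
//   { return vec_transform(xs...)([](auto... ys) { return cos(ys...); }); }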
@@ -3,19 +3,45 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/functional.hpp>
+#include <migraphx/kernels/math.hpp>
 #include <migraphx/kernels/preload.hpp>
 #include <migraphx/kernels/vectorize.hpp>
 #include <migraphx/kernels/args.hpp>

 namespace migraphx {
template <class T>
struct implicit_conversion_op
{
T x;
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
template <class U>
constexpr operator U() const
{
return x;
}
};
template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
return {x};
}
 template <class F, class T, class... Ts>
 __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 {
     preload<typename T::type>(idx, xs...)([&](auto... ps) {
         idx.global_stride(out.get_shape().elements(), [&](auto i) {
             auto multi_idx = out.get_shape().multi(i);
-            out[multi_idx] = f(ps[multi_idx]...);
+            out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
         });
     });
 }
@@ -23,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 template <class F, class... Ts>
 __device__ void pointwise(F f, Ts*... ps)
 {
-    auto t = transform_args(make_tensors(), rotate_last());
+    auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
     t(ps...)([&](auto... xs) {
         auto idx = make_index();
         pointwise_tensor(idx, f, xs...);
......
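The implicit_conversion wrapper added above lets the result of the generated pointwise function convert to whatever element type the output tensor has (including vector-to-vector via __builtin_convertvector). A standalone, scalar-only sketch of that wrapper idea; the real code also handles the vec case:

#include <cassert>

// Wraps a value and defers the choice of destination type to the assignment site
template <class T>
struct implicit_conversion_op
{
    T x;
    template <class U>
    constexpr operator U() const
    {
        return x; // converted to U when assigned
    }
};

template <class T>
constexpr implicit_conversion_op<T> implicit_conversion(T x)
{
    return {x};
}

int main()
{
    double d   = 2.75;
    int out    = implicit_conversion(d); // truncates to 2 via operator int()
    float outf = implicit_conversion(d); // converts to 2.75f
    assert(out == 2);
    assert(outf == 2.75f);
    return 0;
}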
@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
 }

 template <class T, class... Shapes>
-constexpr index_int compute_preload_size(Shapes...)
+constexpr index_int compute_preload_size_c(Shapes...)
 {
     index_int size = 0;
     traverse_preload<T>(Shapes{}...)(
@@ -37,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
     return size;
 }

+template <class T, class... Shapes>
+constexpr auto compute_preload_size(Shapes...)
+{
+    return _c<compute_preload_size_c<T>(Shapes{}...)>;
+}
+
 template <class F, class T, class... Ts>
 __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
 {
@@ -48,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
         [&](auto x, auto offset, auto copy) {
             if constexpr(copy)
             {
-                auto v = vectorize(x);
-                auto b = as_vec(tensor_vec_size(v), buffer + offset);
-                idx.local_stride(v.get_shape().element_space(),
-                                 [&](auto i) { b[i] = v.data()[i]; });
-                return x.with(buffer + offset);
+                if constexpr(decltype(tensor_vec_size(x)){} == 0)
+                {
+                    auto v = vectorize(x);
+                    auto b = as_vec(tensor_vec_size(v), buffer + offset);
+                    idx.local_stride(v.get_shape().element_space(),
+                                     [&](auto i) { b[i] = v.data()[i]; });
+                    return x.with(buffer + offset);
+                }
+                else
+                {
+                    auto b = as_vec(tensor_vec_size(x), buffer + offset);
+                    idx.local_stride(x.get_shape().element_space(),
+                                     [&](auto i) { b[i] = x.data()[i]; });
+                    return x.with(b);
+                }
             }
             else
             {
@@ -78,7 +94,7 @@ template <class T, class... Ts>
 __device__ auto preload(index idx, Ts... xs)
 {
     using type = typename remove_vec<T>::type;
-    constexpr auto size = compute_preload_size<type>(xs.get_shape()...);
+    constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
     const index_int max_size = 512 * sizeof(type);
     return [=](auto f) {
         if constexpr(size > 0 and size < max_size)
......
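The compute_preload_size change is the recurring pattern in this diff: keep a plain _c-suffixed constexpr computation, and add a wrapper that hands the result back as an integral constant (via _c<...>), so it stays usable in if constexpr and as a template argument even when the inputs are ordinary function parameters. A standalone sketch of that pattern, using std::integral_constant in place of the kernel's _c helper:

#include <type_traits>

// Plain constexpr computation
template <class... Ts>
constexpr unsigned total_size_c(Ts...)
{
    return (sizeof(Ts) + ... + 0u);
}

// Wrapper returning the result encoded in the type
template <class... Ts>
constexpr auto total_size(Ts...)
{
    return std::integral_constant<unsigned, total_size_c(Ts{}...)>{};
}

template <class... Ts>
constexpr auto describe(Ts... xs)
{
    // decltype(...){}: the value is recovered from the type, so it remains a
    // constant expression even though xs are ordinary function parameters
    constexpr auto size = decltype(total_size(xs...)){};
    if constexpr(size > 4)
        return 'L';
    else
        return 'S';
}

static_assert(describe(1, 2.0) == 'L'); // int + double = 12 bytes on most targets
static_assert(describe('a') == 'S');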
@@ -3,6 +3,7 @@
 #include <migraphx/kernels/types.hpp>
 #include <migraphx/kernels/integral_constant.hpp>
+#include <migraphx/kernels/functional.hpp>

 namespace migraphx {
@@ -24,6 +25,38 @@ constexpr auto vec_size()
     return decltype(vec_size(T{})){};
 }
template <class... Ts>
constexpr auto is_any_vec()
{
if constexpr(sizeof...(Ts) == 0)
return false_type{};
else
return bool_constant<((vec_size<Ts>() + ...) > 0)>{};
}
template <class T, class I>
constexpr auto vec_at(T x, I i)
{
if constexpr(vec_size<T>() == 0)
return x;
else
{
MIGRAPHX_ASSERT(i < vec_size<T>());
return x[i];
}
}
template <class... Ts>
constexpr auto common_vec_size()
{
return fold([](auto x, auto y) {
if constexpr(x > y)
return x;
else
return y;
})(vec_size<Ts>()...);
}
 template <index_int N, class T>
 __device__ __host__ auto as_vec(T* x)
 {
@@ -33,5 +66,28 @@ __device__ __host__ auto as_vec(T* x)
     return reinterpret_cast<vec<T, N>*>(x);
 }
template <class T, index_int N>
using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
template <class... Ts>
constexpr auto vec_transform(Ts... xs)
{
return [=](auto f) {
if constexpr(is_any_vec<Ts...>())
{
using type = decltype(f(vec_at(xs, 0)...));
constexpr auto size = common_vec_size<Ts...>();
safe_vec<type, size> result = {0};
for(int i = 0; i < size; i++)
result[i] = f(vec_at(xs, i)...);
return result;
}
else
{
return f(xs...);
}
};
}
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
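vec_transform above applies a scalar function lane by lane when any argument is a vector and falls through to a plain call otherwise. A host-side sketch of that dispatch using GCC/Clang vector extensions; the assumption that migraphx's vec<T, N> alias is such an extension type, and every name below, is illustrative only:

// Compile with gcc or clang; vector_size is a compiler extension
typedef float float4 __attribute__((vector_size(16)));

template <class T>
constexpr unsigned vec_size_of()
{
    return 0; // scalars report 0, matching vec_size<T>() above
}
template <>
constexpr unsigned vec_size_of<float4>()
{
    return 4;
}

template <class F, class T>
auto vec_transform_sketch(F f, T x)
{
    if constexpr(vec_size_of<T>() == 0)
    {
        return f(x); // scalar: plain call
    }
    else
    {
        T result = x;
        for(unsigned i = 0; i < vec_size_of<T>(); i++)
            result[i] = f(x[i]); // vector: apply per lane
        return result;
    }
}

#include <cassert>
int main()
{
    float4 v   = {1, 2, 3, 4};
    auto twice = [](float a) { return 2 * a; };
    float4 dv  = vec_transform_sketch(twice, v);
    float ds   = vec_transform_sketch(twice, 1.5f);
    assert(dv[3] == 8.0f);
    assert(ds == 3.0f);
    return 0;
}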
@@ -7,40 +7,70 @@
 namespace migraphx {

 template <class T>
-constexpr auto tensor_vec_size(T)
+constexpr auto tensor_vec_size()
 {
     return vec_size<typename T::type>();
 }

-template <index_int N, class Shape>
-constexpr auto as_vec_shape(Shape s)
+template <class T>
+constexpr auto tensor_vec_size(T)
 {
-    auto lens = transform(s.lens, s.strides, [](auto len, auto stride) {
-        if(stride == 1)
-            return len / N;
-        else
-            return len;
-    });
-    auto strides = transform(s.strides, [](auto stride) {
-        if(stride == 1)
-            return stride;
-        return stride / N;
-    });
-    MIGRAPHX_ASSERT(make_shape(lens, strides).element_space() * N == s.element_space());
-    return make_shape(lens, strides);
+    return tensor_vec_size<T>();
+}
+
+template <index_int N, class Shape, class Axis>
+constexpr auto shape_step(Shape s, Axis)
+{
+    static_assert(N > 0, "Vector size must be non-zero");
+    return sequence(s.lens.size(), [&](auto... is) {
+        auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
+            constexpr auto axis = Axis::to();
+            MIGRAPHX_ASSERT(i != 0);
+            MIGRAPHX_ASSERT(j != axis or i % N == 0);
+            if(j == axis)
+                return i / N;
+            else
+                return i;
+        });
+        auto strides = transform(s.strides, index_ints<is...>{}, [&](auto i, auto j) {
+            constexpr auto axis = Axis::to();
+            // If stride of the axis is zero then we don't need to adjust the other strides
+            if(Shape{}.strides[axis] == 0)
+                return i;
+            MIGRAPHX_ASSERT(j == axis or i % N == 0);
+            if(j == axis)
+                return i;
+            else
+                return i / N;
+        });
+        MIGRAPHX_ASSERT(make_shape(lens, strides).elements() * N == s.elements());
+        MIGRAPHX_ASSERT(strides[Axis{}] == 0 or
+                        make_shape(lens, strides).element_space() * N == s.element_space());
+        return make_shape(lens, strides);
+    });
 }
-template <index_int N, class T>
-__device__ __host__ auto as_vec(T x)
+// Bools can not be used as a vector type so convert it to uint8
+template <class T>
+__device__ __host__ T* remove_bool(T* x)
+{
+    return x;
+}
+
+inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
+
+template <index_int N, class T, class Axis>
+__device__ __host__ auto as_vec(T x, Axis axis)
 {
     if constexpr(N == 0)
         return x;
     else
-        return make_tensor_view(as_vec<N>(x.data()), as_vec_shape<N>(x.get_shape()));
+        return make_tensor_view(as_vec<N>(remove_bool(x.data())),
+                                shape_step<N>(x.get_shape(), axis));
 }

 template <index_int N, class T, class Axis>
-constexpr auto tensor_step(T x, Axis)
+constexpr auto tensor_step(T x, Axis axis)
 {
     if constexpr(N == 0)
     {
@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis)
     else
     {
         constexpr auto s = decltype(x.get_shape()){};
-        MIGRAPHX_ASSERT(s.strides[Axis{}] == 0);
-        return sequence(x.get_shape().lens.size(), [&](auto... is) {
-            auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
-                constexpr auto axis = Axis{};
-                if(j == axis)
-                    return i / N;
-                else
-                    return i;
-            });
-            return make_tensor_view(x.data(), make_shape(lens, s.strides));
-        });
+        MIGRAPHX_ASSERT(s.strides[axis] == 0);
+        return make_tensor_view(x.data(), shape_step<N>(s, axis));
     }
 }
@@ -69,42 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
     return as_vec<ic>(x);
 }

-template <class... Shapes>
-constexpr index_int find_vector_axis(Shapes... ss)
-{
-    index_int axis = 0;
-    bool b = false;
-    by([&](auto s) {
-        if(s.broadcasted() or b)
-            return;
-        auto it = find(s.strides.begin(), s.strides.end(), 1);
-        if(it == s.strides.end())
-            return;
-        axis = it - s.strides.begin();
-        b = true;
-    })(ss...);
-    return axis;
-}
+template <class Shape>
+constexpr index_int find_vector_axis_c(Shape s)
+{
+    // Find the fastest axis that is not broadcasted
+    index_int axis = 0;
+    for(index_int i = 1; i < s.lens.size(); i++)
+    {
+        if(s.strides[i] == 0)
+            continue;
+        if(s.strides[axis] == 0 or
+           pack_compare(less{}, pack(s.strides[i], s.lens[i]), pack(s.strides[axis], s.lens[axis])))
+            axis = i;
+    }
+    return axis;
+}
+
+template <class... Shapes>
+constexpr index_int find_vector_axis_c(Shapes... ss)
+{
+    const bool all_broadcasted = (ss.broadcasted() and ...);
+    index_int axis = 0;
+    bool b = false;
+    by([&](auto s) {
+        if(b)
+            return;
+        // Skip broadcasted shapes if there are shapes not broadcasted
+        if(not all_broadcasted and s.broadcasted())
+            return;
+        axis = find_vector_axis_c(s);
+        if(s.strides[axis] == 1)
+            b = true;
+    })(ss...);
+    if(not b)
+        return -1;
+    return axis;
+}
+
+template <class... Shapes>
+constexpr auto find_vector_axis(Shapes...)
+{
+    return _c<find_vector_axis_c(Shapes{}...)>;
+}

 template <index_int N, class Axis, class... Shapes>
-constexpr auto is_vectorizable(Axis axis, Shapes... ss)
+constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
 {
-    return (((ss.lens[axis] % N) == 0 and (ss.strides[axis] == 1 or ss.strides[axis] == 0)) and
+    return ((axis < ss.lens.size() and ss.lens[axis] % N == 0 and
+             // Only vectorize broadcasted types with stride 0, since this causes issues in the
+             // preloader
+             ((not ss.broadcasted() and ss.strides[axis] == 1) or ss.strides[axis] == 0)) and
             ...);
 }

-template <index_int N, class... Shapes>
-constexpr bool is_vectorizable(Shapes... ss)
+template <index_int N, class Axis, class... Shapes>
+constexpr auto is_vectorizable(Axis, Shapes...)
 {
-    return (is_vectorizable<N>(ss, find_vector_axis(ss)) and ...);
+    return _c<is_vectorizable_c<N>(Axis::to(), Shapes{}...)>;
 }

 template <class P>
 constexpr auto find_vectorize_size(P pred)
 {
-    if constexpr(pred(_c<4>))
+    if constexpr(decltype(pred(_c<4>)){})
         return _c<4>;
-    else if constexpr(pred(_c<2>))
+    else if constexpr(decltype(pred(_c<2>)){})
         return _c<2>;
     else
         return _c<0>;
@@ -113,11 +163,12 @@ constexpr auto find_vectorize_size(P pred)
 template <class T>
 __host__ __device__ auto vectorize(T x)
 {
-    if constexpr(vec_size<T>() == 0)
+    if constexpr(tensor_vec_size<T>() == 0)
     {
+        constexpr auto axis = find_vector_axis(x.get_shape());
         constexpr auto n =
-            find_vectorize_size([&](auto i) { return _c<is_vectorizable<i>(x.get_shape())>; });
-        return as_vec<n>(x);
+            find_vectorize_size([&](auto i) { return is_vectorizable<i>(axis, x.get_shape()); });
+        return as_vec<n>(x, axis);
     }
     else
     {
@@ -125,34 +176,46 @@ __host__ __device__ auto vectorize(T x)
     }
 }
template <class F, class... Ts>
inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{
// TODO: Just check there is a single axis of 1
constexpr bool packed_or_broadcasted =
((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
if constexpr(packed_or_broadcasted)
{
constexpr auto axis = decltype(find_vector_axis(xs.get_shape()...)){};
constexpr auto n = find_vectorize_size(
[&](auto i) { return is_vectorizable<i>(axis, xs.get_shape()...); });
by(
[&](auto x) {
constexpr auto s = decltype(x.get_shape()){};
if constexpr(axis < s.strides.size())
{
MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
MIGRAPHX_ASSERT(s.lens[axis] > 0);
MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0);
if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis);
else
return as_vec<n>(x, axis);
}
else
{
return x;
}
},
f)(xs...);
}
else
{
f(xs...);
}
}
 inline __device__ __host__ auto auto_vectorize()
 {
-    return [](auto... xs) {
-        return [=](auto f) {
-            // TODO: Just check there a single axis of 1
-            constexpr bool packed_or_broadcasted =
-                ((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
-            if constexpr(packed_or_broadcasted)
-            {
-                constexpr auto axis = find_vector_axis(xs.get_shape()...);
-                constexpr auto n = find_vectorize_size(
-                    [&](auto i) { return _c<is_vectorizable<i>(axis, xs.get_shape()...)>; });
-                by(
-                    [&](auto x) {
-                        constexpr auto s = x.get_shape();
-                        if constexpr(s.strides[axis] == 0)
-                            return tensor_step<n>(x, axis);
-                        else
-                            return as_vec<n>(x);
-                    },
-                    f)(xs...);
-            }
-            else
-            {
-                f(xs...);
-            }
-        };
-    };
+    return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; };
 }

 } // namespace migraphx
......
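A worked example of the axis selection this file now does: for lens {2, 3, 4} with strides {12, 4, 1}, the rule find_vector_axis_c appears to implement keeps the axis with the smallest nonzero (stride, len) pair (pack_compare with less is a lexicographic comparison), which is axis 2 here, so a packed tensor is vectorized along its innermost dimension, and broadcast (stride 0) axes are skipped. A standalone constexpr sketch of that walk, using plain arrays instead of the kernel shape type:

#include <cstddef>

// Selection rule as intended by find_vector_axis_c above: prefer the axis whose
// (stride, len) pair is lexicographically smallest, skipping stride-0 axes.
template <std::size_t N>
constexpr std::size_t find_vector_axis_sketch(const unsigned (&lens)[N],
                                              const unsigned (&strides)[N])
{
    std::size_t axis = 0;
    for(std::size_t i = 1; i < N; i++)
    {
        if(strides[i] == 0)
            continue;
        bool better = strides[axis] == 0 or strides[i] < strides[axis] or
                      (strides[i] == strides[axis] and lens[i] < lens[axis]);
        if(better)
            axis = i;
    }
    return axis;
}

constexpr unsigned lens[]    = {2, 3, 4};
constexpr unsigned strides[] = {12, 4, 1};
static_assert(find_vector_axis_sketch(lens, strides) == 2, "fastest axis is the innermost one");

constexpr unsigned blens[]    = {4, 8};
constexpr unsigned bstrides[] = {0, 1};
static_assert(find_vector_axis_sketch(blens, bstrides) == 1, "broadcast axis is skipped");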
@@ -101,4 +101,38 @@ TEST_CASE(after_param_broadcast)
     EXPECT(not m.get_output_shapes().back().broadcasted());
 }
TEST_CASE(two_transpose_gather)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m1.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto sd = m1.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), td);
auto bd =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), bd, ind);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m2.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto bd =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
 int main(int argc, const char* argv[]) { test::run(argc, argv); }