"src/nni_manager/vscode:/vscode.git/clone" did not exist on "6f760ab632d3a69c70cade82724096899f107f2e"
Commit 4957715b authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into dev2

parents f99a3036 4ec8209f
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
         {"Flatten", "flatten"},
         {"Floor", "floor"},
         {"Gather", "gather"},
+        {"GatherND", "gathernd"},
         {"Identity", "identity"},
         {"IsNaN", "isnan"},
         {"LeakyRelu", "leaky_relu"},
......
@@ -24,14 +24,17 @@ struct parse_mean : op_parser<parse_mean>
         auto divisor = info.add_literal(
             migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
-        return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
-            // Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
-            data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
-            if(data_i != args[0])
-                return info.add_broadcastable_binary_op("add", mean, data_i);
-            return data_i;
-        });
+        // TODO: Only divide when using floating-point
+        return std::accumulate(args.begin() + 1,
+                               args.end(),
+                               info.add_broadcastable_binary_op("div", args[0], divisor),
+                               [&](auto mean, auto data_i) {
+                                   // Pre-divide each tensor element-wise by n to reduce risk of
+                                   // overflow during summation
+                                   auto div =
+                                       info.add_broadcastable_binary_op("div", data_i, divisor);
+                                   return info.add_broadcastable_binary_op("add", mean, div);
+                               });
     }
 };
......
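The rewritten fold divides each input by n before adding, instead of dividing the finished sum. Both orderings compute the same mean, but pre-dividing keeps every intermediate value in range. A standalone host-side sketch of the difference (illustration only, not MIGraphX code):

#include <cstdio>
#include <limits>

int main()
{
    const float big  = std::numeric_limits<float>::max();
    const float xs[] = {big, big};
    const float n    = 2.0f;

    // Summing first overflows to inf even though the mean is representable.
    float sum_then_divide = (xs[0] + xs[1]) / n;
    // Pre-dividing each term keeps every intermediate finite.
    float divide_then_sum = xs[0] / n + xs[1] / n;

    std::printf("%g vs %g\n", sum_then_divide, divide_then_sum); // inf vs ~3.4e38
}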
@@ -273,6 +273,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
              py::arg("op"),
              py::arg("args"),
              py::arg("mod_args") = std::vector<migraphx::module*>{})
+        .def(
+            "add_literal",
+            [](migraphx::module& mm, py::buffer data) {
+                py::buffer_info info = data.request();
+                auto literal_shape   = to_shape(info);
+                return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
+            },
+            py::arg("data"))
         .def(
             "add_parameter",
             [](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
......
@@ -995,7 +995,7 @@ struct find_split_transpose
         auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
         auto it   = std::find(perm.begin(), perm.end(), axis);
         assert(it != perm.end());
-        auto axis_new = static_cast<int64_t>(std::distance(perm.begin(), it));
+        int64_t axis_new = std::distance(perm.begin(), it);
         for(auto in : split_outputs)
         {
......
@@ -7,7 +7,16 @@
 #ifdef MIGRAPHX_DISABLE_OMP
 #include <migraphx/par_for.hpp>
 #else
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wreserved-identifier"
+#endif
 #include <omp.h>
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
 #endif
 
 namespace migraphx {
......
@@ -319,7 +319,7 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(2);
-        auto s = inputs.at(0);
+        const auto& s = inputs.at(0);
         return {s.type(), s.lens()};
     }
     argument
@@ -357,7 +357,7 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(3);
-        auto s = inputs.at(0);
+        const auto& s = inputs.at(0);
         return {s.type(), s.lens()};
     }
......
@@ -223,7 +223,7 @@ struct cpu_unary2 : auto_register_op<cpu_unary2<Op>>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(1);
-        auto s = inputs.at(0);
+        const auto& s = inputs.at(0);
         return {s.type(), s.lens()};
     }
......
@@ -93,7 +93,7 @@ add_library(migraphx_device
 )
 add_library(compile_for_gpu INTERFACE)
 target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
-target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument)
+target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
 check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
 if(HAS_HIP_LAMBDA_HOST_DEVICE)
     message(STATUS "Enable -fhip-lambda-host-device")
......
@@ -22,6 +22,7 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
 
 #if MIGRAPHX_USE_HIPRTC
@@ -133,6 +134,7 @@ struct hiprtc_program
         std::vector<char> buffer(n);
         MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
         assert(buffer.back() == 0);
+        // cppcheck-suppress returnDanglingLifetime
         return {buffer.begin(), buffer.end() - 1};
     }
@@ -246,6 +248,16 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
         MIGRAPHX_THROW("Missing hsaco");
     };
 
+    if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
+    {
+        for(const auto& src : srcs)
+        {
+            if(src.path.extension() != ".cpp")
+                continue;
+            std::cout << std::string(src.content.first, src.len()) << std::endl;
+        }
+    }
+
     if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
     {
......
@@ -20,7 +20,7 @@ struct run_op : action<run_op>
         auto op = make_op(name);
         if(v.contains("fields"))
             op.from_value(v.at("fields"));
-        double t = time_op(ctx, op, inputs);
+        double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
         std::cout << op << ": " << t << "ms" << std::endl;
     }
 };
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {

__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
    make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
        auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
        gathernd(xs..., settings);
    });
}

}

} // namespace migraphx

)__migraphx__";

struct gathernd_compiler : compiler<gathernd_compiler>
{
    std::vector<std::string> names() const { return {"gathernd"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        auto out_s = inputs.back();
        options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
        options.inputs         = inputs;
        options.output         = out_s;
        options.kernel_name    = "gathernd_kernel";
        options.virtual_inputs = inputs;

        // batch_dims
        assert(v.contains("batch_dims"));
        auto batch_dims = v.at("batch_dims").to<int64_t>();
        options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);

        return compile_hip_code_object(gathernd_kernel, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -6,6 +6,7 @@
 #include <migraphx/cpp_generator.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/eliminate_common_subexpression.hpp>
@@ -28,7 +29,8 @@ ${preamble}
 extern "C" {
 __global__ void kernel(${params})
 {
-    pointwise(${lambda}, ${args});
+    auto idx = make_index();
+    pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
 }
 }
@@ -41,40 +43,105 @@ struct pointwise_compiler : compiler<pointwise_compiler>
 {
     std::vector<std::string> names() const { return {"pointwise"}; }
 
-    static std::size_t oversubscribe(const std::vector<shape>& inputs)
+    static std::size_t oversubscribe_if(bool b)
     {
-        if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
-            return 1;
-        else
+        if(b)
             return 256;
+        else
+            return 1;
+    }
+    static std::size_t find_fast_axis(const std::vector<shape>& inputs)
+    {
+        auto permutation = find_permutation(inputs);
+        auto it          = std::max_element(permutation.begin(), permutation.end());
+        return it - permutation.begin();
     }
-    static std::size_t vectorize_elements(const std::vector<shape>& inputs)
+    static std::vector<bool> preload(std::size_t axis, const std::vector<shape>& inputs)
     {
-        std::size_t n = inputs.front().elements();
+        const std::size_t max_lds_bytes = 4096;
+        std::vector<bool> result;
+        std::transform(inputs.begin(),
+                       inputs.end(),
+                       std::back_inserter(result),
+                       [&](const shape& input) { return input.strides()[axis] == 0; });
+        auto bytes = std::inner_product(inputs.begin(),
+                                        inputs.end(),
+                                        result.begin(),
+                                        std::size_t{0},
+                                        std::plus<>{},
+                                        [](const shape& s, bool b) -> std::size_t {
+                                            if(b)
+                                                return s.bytes();
+                                            return 0;
+                                        });
+        if(bytes < max_lds_bytes)
+            return result;
+        // TODO: Try to partially preload items
+        std::fill(result.begin(), result.end(), false);
+        return result;
+    }
+    static std::string preload_str(const std::vector<bool>& bs)
+    {
+        std::vector<std::string> bool_strs;
+        std::transform(bs.begin(), std::prev(bs.end()), std::back_inserter(bool_strs), [](bool b) {
+            if(b)
+                return "true";
+            return "false";
+        });
+        return "false, " + join_strings(bool_strs, ", ");
+    }
+    static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
+    {
+        // If all inputs are half then only use half2
        if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
-               return s.packed() or s.broadcasted();
+               return s.type() == shape::half_type;
            }))
-        {
-            if((n % 4) == 0)
-                return n / 4;
-            else if((n % 2) == 0)
-                return n / 2;
-        }
-        return n;
+            return {2};
+        return {4, 2};
+    }
+    static auto vectorize_elements(std::size_t axis, const std::vector<shape>& inputs)
+    {
+        auto sizes = vector_sizes(inputs);
+        std::vector<std::size_t> max_vec_size;
+        std::transform(inputs.begin(),
+                       inputs.end(),
+                       std::back_inserter(max_vec_size),
+                       [&](const auto& input) -> std::size_t {
+                           auto stride = input.strides()[axis];
+                           auto len    = input.lens()[axis];
+                           if(stride != 0 and stride != 1)
+                               return 1;
+                           auto it = std::find_if(
+                               sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
+                           if(it != sizes.end())
+                               return *it;
+                           return 1;
+                       });
+        return *std::min_element(max_vec_size.begin(), max_vec_size.end());
     }
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(
-            v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
         options.inputs         = inputs;
         options.output         = inputs.back();
         options.virtual_inputs = reduce_dims(inputs);
         options.params         = "-Wno-float-equal";
+        auto axis     = find_fast_axis(options.virtual_inputs);
+        auto vec_size = vectorize_elements(axis, options.virtual_inputs);
+        auto preloads = preload(axis, options.virtual_inputs);
+        auto is_preloading =
+            std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
+        options.set_launch_params(v,
+                                  compute_global_for(ctx,
                                                     options.output.elements() / vec_size,
+                                                     oversubscribe_if(not is_preloading)));
         auto src = interpolate_string(pointwise_kernel,
                                       {{"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"lambda", v.at("lambda").to<std::string>()},
+                                       {"vec_size", std::to_string(vec_size)},
+                                       {"axis", std::to_string(axis)},
+                                       {"preloads", preload_str(preloads)},
                                        {"preamble", v.get("preamble", std::string{})}});
         return compile_hip_code_object(src, options);
     }
......
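find_fast_axis picks the dimension that varies fastest across the reduced input shapes, and vectorize_elements then chooses the widest load (4 or 2 elements, half2 only when every input is half) that all inputs can legally use on that axis: the stride there must be 0 (broadcast) or 1 (contiguous), and the axis length must be divisible by the width. A host-side sketch of that selection on plain lens/strides vectors (illustration only, not part of this commit):

#include <algorithm>
#include <cstddef>
#include <vector>

// A shape can be vectorized along `axis` only when its stride there is 0
// (broadcast) or 1 (contiguous) and the axis length divides the candidate
// width; the kernel then takes the minimum width over all inputs.
std::size_t pick_vec_size(std::size_t axis,
                          const std::vector<std::vector<std::size_t>>& lens,
                          const std::vector<std::vector<std::size_t>>& strides)
{
    const std::vector<std::size_t> widths = {4, 2}; // {2} if everything is half
    std::size_t result = widths.front();
    for(std::size_t k = 0; k < lens.size(); k++)
    {
        std::size_t best = 1;
        if(strides[k][axis] == 0 or strides[k][axis] == 1)
        {
            auto it = std::find_if(widths.begin(), widths.end(), [&](auto w) {
                return lens[k][axis] % w == 0;
            });
            if(it != widths.end())
                best = *it;
        }
        result = std::min(result, best);
    }
    return result;
}

// pick_vec_size(1, {{64, 8}, {64, 8}}, {{8, 1}, {0, 1}}) == 4: both inputs
// are contiguous (or broadcast) on the fast axis and 8 % 4 == 0. A
// transposed input with stride > 1 on that axis would force width 1.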
@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
 {
     make_tensors()(input_p, output_p)([](auto input, auto output) {
-        simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write});
+        simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
     });
 }
@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
     return get_reduce_elements(to_shapes(inputs));
 }
 
+static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
+                                                const std::vector<std::size_t>& output_lens)
+{
+    std::vector<std::size_t> reduce_lens;
+    std::transform(output_lens.begin(),
+                   output_lens.end(),
+                   input_lens.begin(),
+                   std::back_inserter(reduce_lens),
+                   [](auto x, auto y) -> std::size_t {
+                       if(x == y)
+                           return 1;
+                       else
+                           return y;
+                   });
+    return reduce_lens;
+}
+
+static std::string get_reduce_algo(const std::vector<shape>& inputs)
+{
+    auto rlens      = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
+    const auto init = std::numeric_limits<std::size_t>::max();
+    // The minimum stride
+    auto min_stride = std::inner_product(
+        rlens.begin(),
+        rlens.end(),
+        inputs.front().strides().begin(),
+        init,
+        [](auto x, auto y) { return std::min(x, y); },
+        [](auto len, auto stride) { return len == 1 ? init : stride; });
+    if(min_stride > 2)
+        return "lane";
+    return "block";
+}
+
 struct reduce_compiler : compiler<reduce_compiler>
 {
     std::vector<std::string> names() const
@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
     {
         hip_compile_options options;
         auto reduce_elements = get_reduce_elements(inputs);
-        auto block_size      = compute_block_size(reduce_elements, 256);
-        options.set_launch_params(
-            v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
+        auto algo            = v.get("algo", get_reduce_algo(inputs));
+        if(algo == "block")
+        {
+            auto block_size = compute_block_size(reduce_elements, 256);
+            options.set_launch_params(
+                v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
+        }
+        else if(algo == "lane")
+        {
+            options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), 256));
+        }
+        else
+        {
+            MIGRAPHX_THROW("Unknown reduce algo: " + algo);
+        }
         options.inputs         = inputs;
         options.output         = inputs.back();
         options.virtual_inputs = reduce_dims(inputs);
-        options.params         = "-Wno-float-equal";
         std::string identity = "[](auto x) { return x; }";
         auto src = interpolate_string(simple_reduce_kernel,
                                       {{"reduction", v.at("reduction").to<std::string>()},
                                        {"init", v.get("init", std::string{"0"})},
                                        {"read", v.get("read", identity)},
                                        {"write", v.get("write", identity)},
+                                       {"algo", algo},
                                        {"preamble", v.get("preamble", std::string{})}});
+        options.params += "-Wno-float-equal";
         return compile_hip_code_object(src, options);
     }
......
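get_reduce_algo's heuristic: over the reduced axes (where the output length differs from the input's), take the smallest input stride. If even that stride is greater than 2, neighbouring threads would not coalesce loads anyway, so each lane reduces its own slice ("lane"); otherwise a whole workgroup cooperates on each output ("block"). A standalone sketch with two worked shapes (illustration only, not part of this commit):

#include <algorithm>
#include <cstddef>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

// rlens holds the reduced lengths (1 where the axis is kept); kept axes
// contribute "infinity" to the minimum, reduced axes contribute their stride.
std::string reduce_algo(const std::vector<std::size_t>& rlens,
                        const std::vector<std::size_t>& strides)
{
    constexpr auto init = std::numeric_limits<std::size_t>::max();
    auto min_stride     = std::inner_product(
        rlens.begin(),
        rlens.end(),
        strides.begin(),
        init,
        [](auto x, auto y) { return std::min(x, y); },
        [](auto len, auto stride) { return len == 1 ? init : stride; });
    return min_stride > 2 ? "lane" : "block";
}

// reduce_algo({1, 1024}, {1024, 1}) == "block": the reduced axis is
// contiguous, so a workgroup can make coalesced loads toward one output.
// reduce_algo({1024, 1}, {4, 1}) == "lane": stride 4 on the reduced axis,
// so each thread strides through its own slice independently.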
@@ -52,9 +52,8 @@ struct scatternd_compiler : compiler<scatternd_compiler>
     {
         hip_compile_options options;
         options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
-        auto out_s = inputs.back();
         options.inputs         = inputs;
-        options.output         = out_s;
+        options.output         = inputs.back();
         options.kernel_name    = "scatternd_kernel";
         options.virtual_inputs = inputs;
         auto reduction = "assign_" + v.get("reduction", std::string{"none"});
......
@@ -21,6 +21,16 @@ struct greater
     }
 };
 
+template <class InputIt, class T, class BinaryOperation>
+constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
+{
+    for(; first != last; ++first)
+    {
+        init = op(std::move(init), *first);
+    }
+    return init;
+}
+
 template <class InputIt, class OutputIt>
 constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
 {
......
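This accumulate mirrors std::accumulate but is constexpr, so shape arithmetic such as the dimension products in the gathernd kernel below can fold at compile time. A host-side sanity check of the same definition (illustration only, compiled as ordinary C++17):

#include <array>
#include <cstddef>
#include <functional>
#include <utility>

template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
    for(; first != last; ++first)
        init = op(std::move(init), *first);
    return init;
}

// The product of a {2, 3, 4} shape is evaluated entirely at compile time.
constexpr std::array<std::size_t, 3> lens = {2, 3, 4};
static_assert(accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<>{}) == 24,
              "2 * 3 * 4 = 24 elements");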
@@ -42,6 +42,32 @@ struct print_buffer
             pos++;
         }
     }
+
+    template <class T, class = decltype(T{} % 10, -T{})>
+    constexpr void append(T i)
+    {
+        if(i < 0)
+        {
+            append('-');
+            i = -i;
+        }
+        char c = (i % 10) + '0';
+        if(i > 9)
+            append(i / 10);
+        append(c);
+    }
+
+    constexpr void append(const char* str)
+    {
+        if(str == nullptr)
+            return;
+        int i = 512;
+        while(*str != 0 and i > 0)
+        {
+            append(*str);
+            str++;
+            i--;
+        }
+    }
+
     template <size_t M>
     constexpr void append(const char (&array)[M])
@@ -54,14 +80,36 @@
 template <class... Ts>
 __host__ __device__ void print(const Ts&... xs)
 {
-    const auto size = (sizeof(xs) + ...);
-    print_buffer<size> buffer;
+    print_buffer<1024> buffer;
     swallow{(buffer.append(xs), 0)...};
     printf("%s", buffer.buffer);
 }
 
 } // namespace debug
 
+struct source_location
+{
+    int line             = __builtin_LINE();
+    const char* file     = __builtin_FILE();
+    const char* function = __builtin_FUNCTION();
+};
+
+template <class T>
+struct source_location_capture
+{
+    T x;
+    source_location loc;
+
+    template <class U, class = decltype(T(U{}))>
+    constexpr source_location_capture(U px, source_location ploc = source_location{})
+        : x(px), loc(ploc)
+    {
+    }
+
+    constexpr operator source_location() const { return loc; }
+    constexpr operator T() const { return x; }
+};
+
 // noreturn cannot be used on this function because abort in hip is broken
 template <class T1, class T2, class T3, class T4>
 MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
     abort();
 }
 
+template <class... Ts>
+MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc,
+                                                                  Ts... xs)
+{
+    debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
+    abort();
+}
+
 // NOLINTNEXTLINE
-#define MIGRAPHX_CHECK(cond) \
+#define MIGRAPHX_ASSERT_FAIL(cond, ...) \
     ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
         assert_fail(private_migraphx_xs...); \
-    }(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
+    }(__VA_ARGS__))
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_CHECK(cond) \
+    MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
 
 #ifdef MIGRAPHX_DEBUG
+// NOLINTNEXTLINE
+#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
+#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
 #define MIGRAPHX_ASSERT MIGRAPHX_CHECK
 #define MIGRAPHX_ASSUME MIGRAPHX_CHECK
 #define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
 #else
+// NOLINTNEXTLINE
+#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
 #define MIGRAPHX_ASSUME __builtin_assume
 #define MIGRAPHX_UNREACHABLE __builtin_unreachable
 #define MIGRAPHX_ASSERT(cond)
+#define MIGRAPHX_WARN(...)
 #endif
 
 } // namespace migraphx
......
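The source_location_capture trick exists because a defaulted source_location parameter cannot follow a variadic pack; wrapping the first argument instead still captures the caller's file/line/function, since default arguments are evaluated at the call site. A standalone host-side sketch of the underlying pattern, using the same GCC/Clang builtins (not part of this commit):

#include <cstdio>

struct source_location
{
    int line             = __builtin_LINE();
    const char* file     = __builtin_FILE();
    const char* function = __builtin_FUNCTION();
};

// The default argument is constructed at the call site, so the builtins
// record where warn() was called, not where it was defined.
void warn(const char* msg, source_location loc = source_location{})
{
    std::printf("%s:%d: %s: warning: %s\n", loc.file, loc.line, loc.function, msg);
}

int main()
{
    warn("tile size rounded down"); // prints this file and line, not warn()'s
}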
@@ -3,6 +3,14 @@
 #include <migraphx/kernels/array.hpp>
 
+// NOLINTNEXTLINE
+#define MIGRAPHX_RETURNS(...) \
+    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_LIFT(...) \
+    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+
 namespace migraphx {
 
 struct swallow
@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
     return [=](auto f) { return f(xs...); };
 }
 
+template <class G, class F>
+constexpr auto join(G g, F f)
+{
+    return f([=](auto... xs) { return g(xs...); });
+}
+
+template <class G, class F, class... Fs>
+constexpr auto join(G g, F f, Fs... fs)
+{
+    return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); });
+}
+
 template <class Compare, class P1, class P2>
 constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
 {
@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
     return arg_c<ic>();
 }
 
-inline constexpr auto rotate_last()
+template <class F>
+constexpr auto make_transform(F f)
 {
-    return [](auto... xs) {
-        return [=](auto&& f) {
-            return sequence_c<sizeof...(xs)>([&](auto... is) {
-                constexpr auto size = sizeof...(is);
-                return f(arg_c<(is + size - 1) % size>()(xs...)...);
-            });
-        };
-    };
+    return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
 }
 
+// An arg transformation takes the arguments and then a function to take the new arguments:
+// transform(xs...)([](auto... ys) { ... })
+// The transform_args function takes a list of transformations and continually applies them
 template <class F>
 constexpr auto transform_args(F f)
 {
-    return [=](auto... xs) {
-        return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
-    };
+    return f;
 }
 
 template <class F, class... Fs>
 constexpr auto transform_args(F f, Fs... fs)
 {
-    return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
+    return make_transform([=](auto g, auto... xs) {
+        return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
+    });
 }
 
-// NOLINTNEXTLINE
-#define MIGRAPHX_RETURNS(...) \
-    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+// identity transform
+inline constexpr auto transform_args()
+{
+    return make_transform([](auto f, auto... xs) { return f(xs...); });
+}
 
-// NOLINTNEXTLINE
-#define MIGRAPHX_LIFT(...) \
-    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+// Rotate the first argument to the last argument
+inline constexpr auto rotate_last()
+{
+    return make_transform([](auto f, auto... xs) {
+        return sequence_c<sizeof...(xs)>([&](auto... is) {
+            constexpr auto size = sizeof...(is);
+            return f(arg_c<(is + size - 1) % size>()(xs...)...);
+        });
+    });
+}
 
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
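The new protocol is continuation-passing: a transform receives the final function g plus the current arguments and invokes g with rewritten arguments, and transform_args chains such transforms left to right. A host-side model with two toy transforms (illustration only; add_one and doubled are made-up names):

#include <cstdio>

template <class F>
constexpr auto make_transform(F f)
{
    // Curry f into the shape transform(xs...)(g) that transform_args composes
    return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
}

template <class F>
constexpr auto transform_args(F f)
{
    return f; // a single transform is used as-is
}

template <class F, class... Fs>
constexpr auto transform_args(F f, Fs... fs)
{
    // Run f, then feed its rewritten arguments to the remaining transforms
    return make_transform([=](auto g, auto... xs) {
        return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
    });
}

int main()
{
    auto add_one = make_transform([](auto f, auto... xs) { return f((xs + 1)...); });
    auto doubled = make_transform([](auto f, auto... xs) { return f((xs * 2)...); });

    // add_one rewrites (3, 4) to (4, 5), then doubled rewrites to (8, 10)
    transform_args(add_one, doubled)(3, 4)(
        [](auto a, auto b) { std::printf("%d %d\n", a, b); }); // prints "8 10"
}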
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>

namespace migraphx {

template <class T>
struct gathernd_settings
{
    T batch_dims{};
};

template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
    return {xs...};
}

template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
    auto ind                = make_index();
    auto batch_dims         = s.batch_dims;
    auto output_shape       = output_t.get_shape();
    auto indices_shape      = indices_t.get_shape();
    auto data_shape         = data_t.get_shape();
    auto indices_shape_lens = indices_shape.lens;
    auto data_shape_lens    = data_shape.lens;
    auto num_slice_dims     = indices_shape_lens.back();

    std::size_t num_slices = accumulate(indices_shape_lens.begin(),
                                        indices_shape_lens.end() - 1,
                                        1,
                                        std::multiplies<std::size_t>());
    std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                        data_shape_lens.end(),
                                        1,
                                        std::multiplies<std::size_t>());
    const std::size_t num_batches       = accumulate(data_shape_lens.begin(),
                                               data_shape_lens.begin() + batch_dims,
                                               1,
                                               std::multiplies<std::size_t>());
    const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
                                                     data_shape_lens.end(),
                                                     1,
                                                     std::multiplies<std::size_t>());
    const auto num_slices_per_batch     = num_slices / num_batches;

    ind.global_stride(output_shape.elements(), [&](auto i) {
        const auto* indices_ptr    = indices_t.data();
        const std::size_t j         = i / slice_size;
        const std::size_t batch_idx = j / num_slices_per_batch;

        auto* slice_indices               = indices_ptr + (j * num_slice_dims);
        std::size_t relative_slice_offset = 0;
        for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
        {
            int64_t index                   = slice_indices[idx];
            const std::size_t input_dim_idx = batch_dims + idx;
            const auto input_dim            = data_shape_lens[input_dim_idx];
            assert(index >= -static_cast<int64_t>(input_dim) and
                   index < static_cast<int64_t>(input_dim));
            if(index < 0)
                index += input_dim;
            std::size_t size_from_slice_dims =
                accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                           data_shape_lens.begin() + batch_dims + num_slice_dims,
                           slice_size,
                           std::multiplies<std::size_t>());
            relative_slice_offset += index * size_from_slice_dims;
        }

        auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
        output_t[i]       = data_t[slice_offset + i % slice_size];
    });
}
} // namespace migraphx
#endif
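For intuition on the indexing: each output element i belongs to slice j = i / slice_size; the first num_slice_dims entries of row j of the indices tensor select which data slice to copy, with negative indices wrapping as in ONNX GatherND. A minimal host-side reference for the batch_dims = 0, single-index-dimension case (illustration only, not part of this commit):

#include <cassert>
#include <cstddef>
#include <vector>

// data {2, 2} = {{0, 1}, {2, 3}}, indices {2, 1} = {{1}, {0}}:
// num_slice_dims = 1 and slice_size = 2, so output i copies element
// i % slice_size of the slice chosen by row i / slice_size.
std::vector<int> gathernd_ref(const std::vector<int>& data,
                              const std::vector<std::size_t>& indices,
                              std::size_t slice_size)
{
    std::vector<int> out;
    for(std::size_t j : indices)
        for(std::size_t k = 0; k < slice_size; ++k)
            out.push_back(data[j * slice_size + k]);
    return out;
}

int main()
{
    auto out = gathernd_ref({0, 1, 2, 3}, {1, 0}, 2);
    assert((out == std::vector<int>{2, 3, 0, 1})); // rows swapped
}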
@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
 template <class F, class T, class... Ts>
 __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 {
-    preload<typename T::type>(idx, xs...)([&](auto... ps) {
-        idx.global_stride(out.get_shape().elements(),
-                          [&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
-    });
+    idx.global_stride(out.get_shape().elements(),
+                      [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
 }
 
-template <class F, class... Ts>
-__device__ void pointwise(F f, Ts*... ps)
+template <class... Transforms>
+__device__ auto pointwise(index idx, Transforms... transforms)
 {
-    auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
-    t(ps...)([&](auto... xs) {
-        auto idx = make_index();
-        pointwise_tensor(idx, f, xs...);
-    });
+    return [=](auto f, auto*... ps) {
+        auto t = transform_args(make_tensors(), rotate_last(), transforms...);
+        t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
+    };
 }
 
 } // namespace migraphx
......
@@ -3,6 +3,8 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/functional.hpp>
+#include <migraphx/kernels/tensor_view.hpp>
+#include <migraphx/kernels/vec.hpp>
 
 namespace migraphx {
@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
     {
         if constexpr(decltype(tensor_vec_size(x)){} == 0)
         {
-            auto v = vectorize(x);
+            auto v = auto_vectorize(x);
             auto b = as_vec(tensor_vec_size(v), buffer + offset);
             idx.local_stride(v.get_shape().element_space(),
                              [&](auto i) { b[i] = v.data()[i]; });
@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
     };
 }
 
+inline __device__ auto auto_preload(index idx)
+{
+    return make_transform([=](auto f, auto out, auto... xs) {
+        preload<typename decltype(out)::type>(idx, xs...)([&](auto... ys) { f(out, ys...); });
+    });
+}
+
+template <bool B, class T>
+__device__ auto preload_copy(index idx, T x)
+{
+    return [=](auto f) {
+        if constexpr(B)
+        {
+            using type          = typename T::type;
+            constexpr auto size = get_shape_c<T>{}.element_space();
+            __shared__ type buffer[size];
+            // TODO: Always vectorize when size > 4, and then use a second loop for remainder
+            constexpr auto n = find_vectorize_size([&](auto i) { return (size % i) == 0; });
+            auto input       = as_vec<n>(remove_bool(x.data()));
+            auto b           = as_vec<n>(remove_bool(buffer));
+            idx.local_stride(size / n, [&](auto i) { b[i] = input[i]; });
+            return f(x.with(buffer));
+        }
+        else
+        {
+            return f(x);
+        }
+    };
+}
+
+template <bool... Bs>
+__device__ auto auto_preload(index idx)
+{
+    return make_transform([=](auto f, auto... xs) {
+        auto invoke = [=](auto... ys) {
+            __syncthreads();
+            f(ys...);
+        };
+        join(invoke, preload_copy<Bs>(idx, xs)...);
+    });
+}
+
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP