Commit 63387253 authored by charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into select_module_op

parents dd74a89a 962329f3
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -60,6 +61,7 @@ struct reverse ...@@ -60,6 +61,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1);
return inputs[0].with_lens(inputs[0].lens()); return inputs[0].with_lens(inputs[0].lens());
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -46,6 +47,10 @@ struct slice ...@@ -46,6 +47,10 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
} }
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value attributes() const value attributes() const
{ {
value normalize = value::object{}; value normalize = value::object{};
...@@ -65,14 +70,6 @@ struct slice ...@@ -65,14 +70,6 @@ struct slice
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const auto compute_offset(const shape& s) const
{ {
const std::vector<std::size_t>& lens = s.lens(); const std::vector<std::size_t>& lens = s.lens();
...@@ -83,14 +80,14 @@ struct slice ...@@ -83,14 +80,14 @@ struct slice
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis]; offset += starts[i] * strides[axis];
} }
} }
else else
{ {
for(std::size_t axis = 0; axis < lens.size(); axis++) for(std::size_t axis = 0; axis < lens.size(); axis++)
{ {
offset += fix_index(lens, axis, starts[axis]) * strides[axis]; offset += starts[axis] * strides[axis];
} }
} }
return offset; return offset;
...@@ -98,37 +95,81 @@ struct slice ...@@ -98,37 +95,81 @@ struct slice
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; check_shapes{inputs, *this, true}.has(1);
auto t = input_shape.type(); auto input_shape = inputs[0];
const auto& old_lens = input_shape.lens(); auto t = input_shape.type();
const auto& old_strides = input_shape.strides();
if(std::any_of( // TODO: When support for dynamic shapes is added to normalize_attributes,
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); })) // remove this restriction.
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{ {
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range"); MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
} }
if(starts.size() != axes.size() or axes.size() != ends.size()) // For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides;
if(input_shape.dynamic())
{
old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
}
else
{ {
MIGRAPHX_THROW("SLICE: inconsistent sizes"); old_lens = input_shape.lens();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides = input_shape.strides();
} }
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
new_lens[axis] = size_t sliced_length = ends[i] - starts[i];
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]); // A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens[axis] = std::min(new_lens[axis], sliced_length);
if(input_shape.dynamic())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
}
}
if(input_shape.dynamic())
{
return shape{t, new_mins, new_lens, new_opts};
}
else
{
return shape{t, new_lens, old_strides};
} }
return shape{t, new_lens, old_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }}; auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size();
return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps; std::vector<int64_t> steps;
// slice can have up to 5 inputs, we first check the 5th one // slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice // to decide whether MIGRAPHX can handle this slice.
if(args.size() == 5) if(args.size() == 5)
{ {
migraphx::argument step_arg = args.back()->eval(); migraphx::argument step_arg = args.back()->eval();
...@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice> ...@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
} }
// If axes arg is not given, the default is all of them.
if(op.axes.empty()) if(op.axes.empty())
{ {
std::vector<int64_t> axes(args[0]->get_shape().lens().size()); std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0}); std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes; op.axes = axes;
} }
...@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert(op.axes.size() == op.starts.size()); assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size()); assert(op.axes.size() == op.ends.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size())) for(auto i : range(steps.size()))
{ {
if(steps[i] >= 0) if(steps[i] >= 0)
...@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice> ...@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]); auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty()) if(not raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
}
// If any steps are other than default 1, add a "steps" op
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; })) if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{ {
std::vector<int64_t> nsteps; std::vector<int64_t> nsteps;
......
...@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options) ...@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert(not this->is_compiled()); assert(not this->is_compiled());
this->impl->target_name = t.name(); this->impl->target_name = t.name();
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout}; options.trace = tracer{std::cout};
options.trace(*this); options.trace(*this);
options.trace(); options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace); run_passes(*this, passes, options.trace);
auto mods = this->get_modules(); auto mods = this->get_modules();
// Validate and finalize // Validate and finalize
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
......
...@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("is_compiled", &migraphx::program::is_compiled) .def("is_compiled", &migraphx::program::is_compiled)
.def( .def(
"compile", "compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) { [](migraphx::program& p,
const migraphx::target& t,
bool offload_copy,
bool fast_math,
bool exhaustive_tune) {
migraphx::compile_options options; migraphx::compile_options options;
options.offload_copy = offload_copy; options.offload_copy = offload_copy;
options.fast_math = fast_math; options.fast_math = fast_math;
options.exhaustive_tune = exhaustive_tune;
p.compile(t, options); p.compile(t, options);
}, },
py::arg("t"), py::arg("t"),
py::arg("offload_copy") = true, py::arg("offload_copy") = true,
py::arg("fast_math") = true) py::arg("fast_math") = true,
py::arg("exhaustive_tune") = false)
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); }) .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def( .def(
"create_module", "create_module",
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# THE SOFTWARE. # THE SOFTWARE.
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc) list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip)
find_package(miopen) find_package(miopen)
# rocblas # rocblas
...@@ -170,21 +170,6 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp ...@@ -170,21 +170,6 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH)
if(NOT CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel
PATH_SUFFIXES bin
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATHS
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
)
endif()
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
...@@ -220,7 +205,6 @@ else() ...@@ -220,7 +205,6 @@ else()
target_compile_definitions(migraphx_gpu PRIVATE target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}" "-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}" "-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
) )
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER) if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
...@@ -252,8 +236,6 @@ else() ...@@ -252,8 +236,6 @@ else()
message(STATUS "MIOpen does not have find mode api") message(STATUS "MIOpen does not have find mode api")
endif() endif()
# Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
......
...@@ -207,12 +207,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -207,12 +207,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
#else // MIGRAPHX_USE_HIPRTC #else // MIGRAPHX_USE_HIPRTC
bool is_hcc_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "hcc");
return result;
}
bool is_hip_clang_compiler() bool is_hip_clang_compiler()
{ {
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++"); static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
...@@ -236,7 +230,7 @@ std::vector<std::vector<char>> ...@@ -236,7 +230,7 @@ std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
assert(not srcs.empty()); assert(not srcs.empty());
if(not is_hcc_compiler() and not is_hip_clang_compiler()) if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " + MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER))); std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
...@@ -246,16 +240,9 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -246,16 +240,9 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{})) if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g"; params += " -g";
params += " -c"; params += " -c";
if(is_hcc_compiler()) params += " --offload-arch=" + arch;
{ params += " --cuda-device-only";
params += " -amdgpu-target=" + arch; params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
else if(is_hip_clang_compiler())
{
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG"; params += " -DMIGRAPHX_DEBUG";
...@@ -270,24 +257,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -270,24 +257,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(has_compiler_launcher()) if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER); compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif #endif
if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path {
process{MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL) + std::string{" -i "} +
obj_path.string()}
.cwd(obj_path.parent_path());
for(const auto& entry : fs::directory_iterator{obj_path.parent_path()})
{
const auto& hsaco_path = entry.path();
if(not fs::is_regular_file(hsaco_path))
continue;
if(hsaco_path.extension() != ".hsaco")
continue;
return hsaco_path;
}
MIGRAPHX_THROW("Missing hsaco");
};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{ {
for(const auto& src : srcs) for(const auto& src : srcs)
......
...@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) ...@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
#ifdef MIGRAPHX_USE_CLANG_TIDY #ifdef MIGRAPHX_USE_CLANG_TIDY
#define MIGRAPHX_DEVICE_SHARED #define MIGRAPHX_DEVICE_SHARED
#else #else
// Workaround hcc's broken tile_static macro
#ifdef tile_static
#undef tile_static
#define MIGRAPHX_DEVICE_SHARED __attribute__((tile_static))
#else
#define MIGRAPHX_DEVICE_SHARED __shared__ #define MIGRAPHX_DEVICE_SHARED __shared__
#endif #endif
#endif
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -216,6 +216,10 @@ struct context ...@@ -216,6 +216,10 @@ struct context
return *current_device; return *current_device;
} }
// Returns the current value of this context's exhaustive_tune member.
bool get_exhaustive_tune_flag() const { return exhaustive_tune; }
// Stores t into this context's exhaustive_tune member.
void set_exhaustive_tune_flag(bool t) { exhaustive_tune = t; }
hip_device::stream& get_stream() { return get_current_device().get_stream(); } hip_device::stream& get_stream() { return get_current_device().get_stream(); }
hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); } hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
...@@ -338,7 +342,8 @@ struct context ...@@ -338,7 +342,8 @@ struct context
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events; std::vector<shared<hip_event_ptr>> events;
bool measure_perf = false; bool exhaustive_tune = false;
bool measure_perf = false;
// for event perf timing // for event perf timing
shared<hip_event_ptr> start_event = nullptr; shared<hip_event_ptr> start_event = nullptr;
shared<hip_event_ptr> stop_event = nullptr; shared<hip_event_ptr> stop_event = nullptr;
......
...@@ -175,8 +175,9 @@ struct miopen_convolution ...@@ -175,8 +175,9 @@ struct miopen_convolution
auto* miopen_stream_handle = ctx.get_stream().get_miopen(); auto* miopen_stream_handle = ctx.get_stream().get_miopen();
solution_ptr = find_solution(miopen_stream_handle, conv_problem.get()); solution_ptr = find_solution(
auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size); miopen_stream_handle, conv_problem.get(), ctx.get_exhaustive_tune_flag());
auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size"); MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size");
...@@ -233,7 +234,7 @@ struct miopen_convolution ...@@ -233,7 +234,7 @@ struct miopen_convolution
&perf, &perf,
workspace.implicit(), workspace.implicit(),
workspace_size, workspace_size,
false); ctx.get_exhaustive_tune_flag());
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed"); MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo; algo = perf.fwd_algo;
......
...@@ -75,12 +75,19 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr ...@@ -75,12 +75,19 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr
using miopen_problem = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem); using miopen_problem = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem);
using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution); using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution);
inline miopen_solution find_solution(miopenHandle_t handle, miopenProblem_t problem) inline miopen_solution
find_solution(miopenHandle_t handle, miopenProblem_t problem, bool tune = false)
{ {
miopenSolution_t solution; miopenSolution_t solution;
size_t found = 0; size_t found = 0;
auto status = miopenFindSolutions(handle, problem, nullptr, &solution, &found, 1); miopen_find_options fo = nullptr;
auto result = miopen_solution{solution}; if(tune)
{
fo = make_obj<miopen_find_options>(&miopenCreateFindOptions);
miopenSetFindOptionTuning(fo.get(), 1);
}
auto status = miopenFindSolutions(handle, problem, fo.get(), &solution, &found, 1);
auto result = miopen_solution{solution};
if(status != miopenStatusSuccess or found == 0) if(status != miopenStatusSuccess or found == 0)
MIGRAPHX_THROW("MIOpen miopenFindSolutions failed"); MIGRAPHX_THROW("MIOpen miopenFindSolutions failed");
return result; return result;
......
...@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
if(relements > block_size * 256)
algo = "block_large";
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
} }
...@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto reduce_elements = get_reduce_elements(ins->inputs()); auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type(); auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}"; v["reduction"] = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}"; std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half // Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384) if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})"; v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
......
...@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l ...@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_WARN(...) #define MIGRAPHX_WARN(...)
#endif #endif
// Expands to a static_assert on the given condition immediately followed by an
// `if constexpr` on the same condition. The statement or braced block written
// right after the macro invocation becomes the body of that `if constexpr`.
// NOTE(review): presumably the `if constexpr` is there so the guarded body is
// discarded when the condition is false, keeping the static_assert as the only
// diagnostic -- confirm with the author.
#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
static_assert(__VA_ARGS__); \
if constexpr(__VA_ARGS__)
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP #endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
...@@ -25,10 +25,6 @@ ...@@ -25,10 +25,6 @@
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP #define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC #ifndef MIGRAPHX_USE_HIPRTC
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/math_functions.h> #include <hip/math_functions.h>
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx { namespace migraphx {
...@@ -135,42 +136,100 @@ struct index ...@@ -135,42 +136,100 @@ struct index
return (n - _c<1>) / stride + _c<1>; return (n - _c<1>) / stride + _c<1>;
} }
// Convenience wrapper: evaluates max_stride_iterations for n elements using
// nglobal() as the stride and returns that result unchanged.
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
return max_stride_iterations(n, nglobal());
}
// Convenience wrapper: evaluates max_stride_iterations for n elements using
// nlocal() as the stride and returns that result unchanged.
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
return max_stride_iterations(n, nlocal());
}
// Loop-body dispatch helper, two-argument form.
// Calls f(i, d) and returns its result. The trailing return type removes this
// overload from consideration when f cannot be called with (i, d).
// NOTE: a callable that accepts both f(i, d) and f(i) makes both overloads
// viable and the call ambiguous.
template <class Callback, class Index, class Counter>
static constexpr auto invoke_loop(Callback f, Index i, Counter d) -> decltype(f(i, d))
{
    return f(i, d);
}

// Loop-body dispatch helper, one-argument form.
// Calls f(i) and returns its result; the third argument is accepted only so
// call sites can pass it unconditionally, and is ignored. The trailing return
// type removes this overload from consideration when f cannot be called with (i).
template <class Callback, class Index, class Counter>
static constexpr auto invoke_loop(Callback f, Index i, Counter) -> decltype(f(i))
{
    return f(i);
}
template <class F, class N, class Stride> template <class F, class N, class Stride>
static constexpr void for_stride_loop_unroll(index_int start, N n, Stride stride, F f)
{
sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) {
auto i = start + stride * k;
if(i < n)
invoke_loop(f, i, d);
return d + _c<1>;
})(_c<0>, ks...);
});
}
// Plain (non-unrolled) strided loop.
// Visits i = start, start + stride, start + 2 * stride, ... for as long as
// i < n, and hands each index to f through invoke_loop together with a
// zero-based count of the iterations performed so far.
template <class F, class N, class Stride>
static constexpr void for_stride_loop(index_int start, N n, Stride stride, F f)
{
    index_int iteration = 0;
    index_int i         = start;
    while(i < n)
    {
        invoke_loop(f, i, iteration);
        ++iteration;
        i += stride;
    }
}
template <bool Unroll, class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
MIGRAPHX_ASSERT(start < stride); MIGRAPHX_ASSERT(start < stride);
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and if constexpr(not is_integral<N>{} and not is_integral<Stride>{})
max_stride_iterations(n, stride) == 1)
{ {
if constexpr(stride > n) if constexpr(max_stride_iterations(n, stride) == 1)
{
if constexpr(stride > n)
{
if(start < n)
invoke_loop(f, start, _c<0>);
}
else
{
invoke_loop(f, start, _c<0>);
}
}
else if constexpr(Unroll)
{ {
if(start < n) MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256)
f(start); {
for_stride_loop_unroll(start, n, stride, f);
}
} }
else else
{ {
f(start); for_stride_loop(start, n, stride, f);
} }
} }
else else
{ {
for(index_int i = start; i < n; i += stride) for_stride_loop(start, n, stride, f);
{
f(i);
}
} }
} }
template <class F, class N> template <class F, class N>
__device__ void global_stride(N n, F f) const __device__ void global_stride(N n, F f) const
{ {
for_stride(global, n, nglobal(), f); for_stride<false>(global, n, nglobal(), f);
} }
template <class F, class N> template <class F, class N>
__device__ void local_stride(N n, F f) const __device__ void local_stride(N n, F f) const
{ {
for_stride(local, n, nlocal(), f); for_stride<true>(local, n, nlocal(), f);
} }
}; };
......
...@@ -46,28 +46,27 @@ template <index_int Axis, ...@@ -46,28 +46,27 @@ template <index_int Axis,
__device__ void generic_binary_layernorm( __device__ void generic_binary_layernorm(
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) { return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
auto x = op(x1, x2); })(input);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps value_type eps_val = eps; // implicit conversion for eps
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto x = op(x1, x2);
auto m = x - mean_x; auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + epsilon) // m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...); y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...); })(output, input, inputs...);
}); });
} }
......
...@@ -66,13 +66,22 @@ struct convert_to ...@@ -66,13 +66,22 @@ struct convert_to
} }
}; };
template <index_int N>
struct mean struct mean
{ {
index_int item_num = 1;
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x) const
{ {
return x / static_cast<T>(item_num); using type = vec_type<T>;
if constexpr(is_floating_point<type>{})
{
constexpr type d = 1.0 / N;
return x * d;
}
else
{
return x / static_cast<type>(N);
}
} }
}; };
......
...@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
const auto ldsidx = idx.local / lanes_per_thread; const auto ldsidx = idx.local / lanes_per_thread;
...@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F> ...@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
...@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i) ...@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
namespace reduce { namespace reduce {
// Empty tag type. Types that derive from it are recognised by is_inner_storage.
struct inner_storage_tag
{
};
// Trait that tests whether T, after removing its reference and cv-qualifiers,
// has inner_storage_tag as a base class.
template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
// Wraps a callable F (by inheriting from it, so F's call operator is reused)
// and attaches a nested `type` alias equal to R.
template <class R, class F>
struct storage_access : F
{
using type = R;
};
// Builds a storage_access<R, F> from the callable f. R must be given
// explicitly by the caller; F is deduced from the argument.
template <class R, class F>
constexpr storage_access<R, F> make_storage_access(F f)
{
// Outer braces initialise storage_access, inner braces initialise its F base.
return {{f}};
}
template <class Slicer, class F> template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f) constexpr auto sliced(Slicer slicer, F f)
{ {
...@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis() ...@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis()
template <class Input, index_int Axis> template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>()); using with_axis = decltype(compute_reduce_axis<Input, Axis>());
// CRTP base shared by the reducer types. It accesses the derived object via
// static_cast<const Derived&>(*this) and relies on Derived providing
// `slice`, `reduce_impl`, `inner_impl` and `inner_void_impl`.
template <class Derived>
struct reducer_base
{
// Converts an argument into something callable as x(i, ...):
//  - an inner-storage object is returned unchanged;
//  - anything else is passed through derived.slice() and wrapped in a
//    storage_access whose call operator returns a reference to t[i].
template <class T>
__device__ auto make_inner_slice(T x) const
{
if constexpr(is_inner_storage<T>{})
{
return x;
}
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
// The trailing `auto...` lets the accessor accept and ignore extra arguments.
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
return t[i];
});
}
}
// Multi-argument overload: returns the size of the first argument and checks,
// through MIGRAPHX_ASSERT, that it equals the size of the remaining arguments.
template <class T, class... Ts>
constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
{
MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
return get_size(x);
}
// Single-argument overload: rsize() for inner-storage objects, otherwise the
// size() of the slice produced by derived.slice(x).
// NOTE(review): the `class... Ts` pack is not used by this overload.
template <class T, class... Ts>
constexpr auto get_size(T&& x) const
{
if constexpr(is_inner_storage<T>{})
{
return x.rsize();
}
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return t.size();
}
}
// Returns a callable that, given the inputs xs..., invokes
// f(size, accessor...) where size is get_size(xs...) and each accessor is
// make_inner_slice(x) for the corresponding input.
template <class F>
__device__ auto inner_sliced(F f) const
{
return [=](auto&&... xs) { return f(get_size(xs...), make_inner_slice(xs)...); };
}
// Declaration only (no definition): used inside decltype below to name a
// `typename T::type&` for each accessor.
template <class T>
static __device__ typename T::type& decl_inner_storage(const T&);
// Returns a callable that applies f across the inputs. The result type of f
// is computed from the accessors' element types: when it is void the work is
// delegated to derived.inner_void_impl (nothing is returned), otherwise the
// result of derived.inner_impl<result_type> is returned.
template <class F>
__device__ auto inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
using result_type = decltype(f(decl_inner_storage(xs)...));
auto&& derived = static_cast<const Derived&>(*this);
if constexpr(is_void<result_type>{})
{
derived.inner_void_impl(f, n, xs...);
}
else
{
return derived.template inner_impl<result_type>(f, n, xs...);
}
});
}
// Returns a callable that forwards op, init, read, the common size n and the
// input accessors to derived.reduce_impl and returns its result.
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
auto&& derived = static_cast<const Derived&>(*this);
return derived.reduce_impl(op, init, read, n, xs...);
});
}
// Convenience overload: same as above with op::id{} as the read function.
template <class Op, class T>
__device__ auto reduce(Op op, T init) const
{
return this->reduce(op, init, op::id{});
}
// Default `outer`: simply invokes f. A derived reducer may declare its own
// `outer` with different behaviour.
template <class F>
__device__ void outer(F f) const
{
f();
}
// Element count of the slice type produced by derived.slice(Input{}), taken
// from its compile-time shape. When Input::type has vec_size > 1 the count is
// multiplied by that vec_size.
template <class Input>
constexpr auto elements() const
{
auto&& derived = static_cast<const Derived&>(*this);
using reduce_type = decltype(derived.slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
};
struct block struct block
{ {
template <class Slicer> template <class Slicer>
struct reducer struct reducer : reducer_base<reducer<Slicer>>
{ {
index idx; index idx;
Slicer slice; Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const template <class T, index_int N, class Size>
struct inner_storage : inner_storage_tag
{
using type = T;
array<T, N> arr;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto& operator()(U, V d) const
{
return arr[d];
}
template <class U, class V>
constexpr auto& operator()(U, V d)
{
return arr[d];
}
};
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{ {
return sliced(slice, [=](auto x, auto... xs) { return block_reduce(idx, op, init, n, [&](auto j, auto d) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { return vec_reduce(read(xs(j, d)...), op);
return vec_reduce(read(x[j], xs[j]...), op);
});
}); });
} }
...@@ -215,31 +354,99 @@ struct block ...@@ -215,31 +354,99 @@ struct block
f(); f();
} }
template <class F> template <class F, class N, class... Ts>
__device__ auto inner(F f) const __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{
idx.local_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
}
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
return sliced(slice, [=](auto x, auto... xs) { using max_iterations = decltype(idx.max_local_stride_iterations(n));
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage;
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{{}, idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
struct block_large
{
template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>>
{
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return block_reduce(idx, op, init, index_int{n}, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
}); });
} }
template <class Input> template <class F>
constexpr auto elements() const __device__ void outer(F f) const
{ {
using reduce_type = decltype(slice(Input{})); if(idx.local == 0)
using value_type = typename Input::type; f();
constexpr auto relements = get_shape_c<reduce_type>{}.elements(); }
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>(); template <class F, class N, class... Ts>
else __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
return relements; {
idx.local_stride(index_int{n}, [&](auto j, auto d) { f(xs(j, d)...); });
}
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
} }
}; };
template <class Slicer> template <class Slicer>
static __device__ auto make(index idx, Slicer slicer) static __device__ auto make(index idx, Slicer slicer)
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F> template <class Output, class F>
...@@ -257,22 +464,40 @@ struct block ...@@ -257,22 +464,40 @@ struct block
struct lane struct lane
{ {
template <class Slicer> template <class Slicer>
struct reducer struct reducer : reducer_base<reducer<Slicer>>
{ {
index idx; index idx;
Slicer slice; Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const template <class Size, class F>
struct inner_storage : inner_storage_tag
{ {
return sliced(slice, [=](auto x, auto... xs) { using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
using type = typename decltype(x)::type; F f;
type r = init; constexpr Size rsize() const { return {}; }
for(index_int j = 0; j < x.get_shape().elements(); j++) template <class U, class V>
{ constexpr auto operator()(U j, V d) const
r = op(r, read(x[j], xs[j]...)); {
} return f(j, d);
return r; }
}); };
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
using type = remove_reference_t<decltype(x(0, _c<0>))>;
type r = init;
for(index_int j = 0; j < n; j++)
{
r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
}
return r;
} }
template <class F> template <class F>
...@@ -281,29 +506,25 @@ struct lane ...@@ -281,29 +506,25 @@ struct lane
f(); f();
} }
template <class F> template <class F, class N, class... Ts>
__device__ auto inner(F f) const __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{ {
return sliced(slice, [=](auto x, auto... xs) { for(index_int j = 0; j < n; j++)
for(index_int j = 0; j < x.get_shape().elements(); j++) {
{ f(xs(j, _c<0>)...);
f(x[j], xs[j]...); }
}
});
} }
template <class Input> template <class R, class F, class N, class... Ts>
constexpr auto elements() const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
using reduce_type = decltype(slice(Input{})); return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return get_shape_c<reduce_type>{}.elements();
} }
}; };
template <class Slicer> template <class Slicer>
static __device__ auto make(index idx, Slicer slicer) static __device__ auto make(index idx, Slicer slicer)
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F> template <class Output, class F>
...@@ -318,6 +539,26 @@ struct lane ...@@ -318,6 +539,26 @@ struct lane
} }
}; };
// TODO: Remove these in the future when they can be selected in the compiler class
template <index_int RElements>
constexpr auto pick_block()
{
using nlocal = decltype(index{}.max_nlocal());
if constexpr(RElements < nlocal{} * 256)
return block{};
else
return block_large{};
}
template <index_int RElements>
using auto_block = decltype(pick_block<RElements>());
// Returns the `lens[Axis]` entry of Input's compile-time shape.
template <class Input, index_int Axis>
constexpr auto reduce_elements_with_axis()
{
    constexpr auto input_shape = get_shape_c<Input>{};
    const auto axis_length     = input_shape.lens[Axis];
    return axis_length;
}
} // namespace reduce } // namespace reduce
template <class Algo, template <class Algo,
......
...@@ -30,18 +30,20 @@ ...@@ -30,18 +30,20 @@
namespace migraphx { namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output) __device__ void softmax(Input input1, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input)[0], 0); const auto c = vec_at(r.slice(input1)[0], 0);
#else #else
const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input); const auto c = r.reduce(op::max{}, lowest{}, op::id{})(input);
#endif #endif
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) { auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
return migraphx::convert<float>(migraphx::exp(x - c)); auto batch_sum =
})(input); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - c) / batch_sum; })(output, input); r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in);
}); });
} }
......
...@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible); ...@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible); MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible); MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
// remove_cv<T>::type is T with any top-level const and/or volatile qualifier
// removed. Types without such qualifiers are returned unchanged.
template <class T>
struct remove_cv
{
    using type = T;
};

// Top-level const only.
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};

// Top-level volatile only.
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};

// Both qualifiers. This specialisation is required: without it a
// `const volatile X` argument matches the `const T` and `volatile T`
// specialisations equally well, neither is more specialised than the other,
// and the instantiation is ill-formed (ambiguous partial specialisation).
template <class T>
struct remove_cv<const volatile T> : remove_cv<T>
{
};

// Shorthand for typename remove_cv<T>::type.
template <class T>
using remove_cv_t = typename remove_cv<T>::type;
template <class T> template <class T>
struct remove_reference struct remove_reference
{ {
...@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*> ...@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T> template <class T>
using add_pointer_t = typename add_pointer<T>::type; using add_pointer_t = typename add_pointer<T>::type;
// Trait derived from is_same<void, U>, where U is T with its cv-qualifiers
// stripped by remove_cv_t; it reports whether that stripped type is void.
template <class T>
struct is_void : is_same<void, remove_cv_t<T>>
{
};
template <class... Ts> template <class... Ts>
struct common_type; struct common_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment