Commit baac1dab authored by Alan Turner

Merge remote-tracking branch 'origin/develop' into ck-host-lib

parents 830dff7a 77042e30
@@ -72,6 +72,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
 std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024);
 
+std::string generate_make_shape(const shape& s);
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
@@ -215,6 +216,10 @@ struct context
         return *current_device;
     }
 
+    bool get_exhaustive_tune_flag() const { return exhaustive_tune; }
+
+    void set_exhaustive_tune_flag(bool t) { exhaustive_tune = t; }
+
     hip_device::stream& get_stream() { return get_current_device().get_stream(); }
     hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
@@ -273,7 +278,8 @@ struct context
         auto v_streams = v.at("streams");
         std::size_t n_streams = v_streams.without_key().to<std::size_t>();
-        this->current_device = std::make_shared<hip_device>(0, n_streams);
+        auto device = get_device_id();
+        this->current_device = std::make_shared<hip_device>(device, n_streams);
     }
 
     void wait_for(any_ptr queue)
@@ -336,7 +342,8 @@ struct context
     // TODO: Make this a vector to support multiple devices
     std::shared_ptr<hip_device> current_device;
     std::vector<shared<hip_event_ptr>> events;
+    bool exhaustive_tune = false;
     bool measure_perf = false;
     // for event perf timing
     shared<hip_event_ptr> start_event = nullptr;
     shared<hip_event_ptr> stop_event = nullptr;
@@ -27,7 +27,6 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/operation.hpp>
-#include <migraphx/register_op.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/op/convolution.hpp>
@@ -175,8 +174,9 @@ struct miopen_convolution
     auto* miopen_stream_handle = ctx.get_stream().get_miopen();
-    solution_ptr = find_solution(miopen_stream_handle, conv_problem.get());
-    auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
+    solution_ptr = find_solution(
+        miopen_stream_handle, conv_problem.get(), ctx.get_exhaustive_tune_flag());
+    auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
     if(status != miopenStatusSuccess)
         MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size");
@@ -233,7 +233,7 @@ struct miopen_convolution
         &perf,
         workspace.implicit(),
         workspace_size,
-        false);
+        ctx.get_exhaustive_tune_flag());
     if(status != miopenStatusSuccess)
         MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
     algo = perf.fwd_algo;
@@ -34,6 +34,8 @@ struct module_pass_manager;
 namespace gpu {
 
+bool mlir_enabled();
+
 struct fuse_mlir
 {
     context* ctx = nullptr;
@@ -29,6 +29,7 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/functional.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <utility>
 
 namespace migraphx {
@@ -98,6 +99,13 @@ struct hip_sync_stream
             return {};
         return args.front();
     }
+
+    std::ptrdiff_t output_alias(const std::vector<shape>& args) const
+    {
+        if(args.empty())
+            return -1;
+        return 0;
+    }
 };
 
 struct hip_copy_to_gpu
@@ -105,7 +113,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2).same_type();
+        check_shapes{inputs, *this, true}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -114,6 +122,10 @@ struct hip_copy_to_gpu
         if(args.size() == 1)
             return input;
         argument result = args[1].share();
+        if(result.get_shape().dynamic())
+        {
+            result = result.reshape(args[0].get_shape());
+        }
         gpu_copy(ctx, input, result);
         // Associate the input since it was registered with hip
         return {result.get_shape(), [input, result]() mutable { return result.data(); }};
@@ -131,19 +143,24 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2).same_type();
+        check_shapes{inputs, *this, true}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    compute(context& ctx, const dyn_output& dyn_out, const std::vector<argument>& args) const
     {
         if(args.size() == 1)
         {
-            argument result = allocate_gpu(output_shape, true);
+            argument result = allocate_gpu(dyn_out.computed_shape, true);
             gpu_copy(ctx, args[0], result);
            return result;
         }
-        copy_from_gpu(ctx, args[0], args[1]);
+        argument input = args[0].share();
+        if(input.get_shape().dynamic())
+        {
+            input = input.reshape(args[1].get_shape());
+        }
+        copy_from_gpu(ctx, input, args[1]);
         return args[1];
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& args) const
@@ -33,6 +33,13 @@ inline namespace MIGRAPHX_INLINE_NS {
 struct module;
 
 namespace gpu {
 
+/**
+ * Compiler pass that makes GPU-specific instruction changes.
+ * * Copies to and from the device if `offload_copy` is true.
+ * * Maps instructions to their GPU-specific counterparts.
+ * * Inserts `allocate` instructions before GPU operators.
+ */
 struct lowering
 {
     context* ctx;
@@ -75,12 +75,19 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr
 using miopen_problem = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem);
 using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution);
 
-inline miopen_solution find_solution(miopenHandle_t handle, miopenProblem_t problem)
+inline miopen_solution
+find_solution(miopenHandle_t handle, miopenProblem_t problem, bool tune = false)
 {
     miopenSolution_t solution;
     size_t found = 0;
-    auto status = miopenFindSolutions(handle, problem, nullptr, &solution, &found, 1);
-    auto result = miopen_solution{solution};
+    miopen_find_options fo = nullptr;
+    if(tune)
+    {
+        fo = make_obj<miopen_find_options>(&miopenCreateFindOptions);
+        miopenSetFindOptionTuning(fo.get(), 1);
+    }
+    auto status = miopenFindSolutions(handle, problem, fo.get(), &solution, &found, 1);
+    auto result = miopen_solution{solution};
     if(status != miopenStatusSuccess or found == 0)
         MIGRAPHX_THROW("MIOpen miopenFindSolutions failed");
     return result;
@@ -56,7 +56,6 @@ struct oper
             return name.substr(pos_ns + 2);
         }
     }
-
     return "unknown_operator_name";
 }
 };
@@ -31,7 +31,6 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/reduce_dims.hpp>
-#include <migraphx/type_name.hpp>
 #include <utility>
 #include <iostream>
@@ -33,7 +33,6 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/type_name.hpp>
 #include <utility>
 #include <iostream>
@@ -31,7 +31,6 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/type_name.hpp>
 #include <utility>
 #include <iostream>
@@ -37,7 +37,6 @@ struct target
     std::string name() const;
     std::vector<pass> get_passes(migraphx::context& gctx, const compile_options& options) const;
     migraphx::context get_context() const;
-
     argument copy_to(const argument& arg) const;
     argument copy_from(const argument& arg) const;
     argument allocate(const shape& s) const;
@@ -78,7 +78,9 @@ struct concat_compiler : compiler<concat_compiler>
         options.params = "-Wno-float-equal";
         options.kernel_name = v.get("kernel", "concat_kernel");
         auto axis = find_fast_axis(options.inputs);
-        auto vec = vectorize::elements(ctx, axis, options.inputs);
+        vectorize vec{};
+        if(axis != v.at("axis").to<std::size_t>())
+            vec = vectorize::elements(ctx, axis, options.inputs);
         options.set_launch_params(
             v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
         auto src = interpolate_string(
@@ -32,7 +32,7 @@ namespace gpu {
 struct mlir_compiler : compiler<mlir_compiler>
 {
-    std::vector<std::string> names() const { return {"gpu::mlir_conv"}; }
+    std::vector<std::string> names() const { return {"gpu::mlir_op"}; }
     operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
@@ -60,15 +60,6 @@ __global__ void reduce_kernel(void* input_p, void* output_p)
 )__migraphx__";
 
-static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
-{
-    return inputs.front().elements() / inputs.back().elements();
-}
-
-static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
-{
-    return get_reduce_elements(to_shapes(inputs));
-}
-
 static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
                                                 const std::vector<std::size_t>& output_lens)
 {
@@ -86,9 +77,28 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
     return reduce_lens;
 }
 
-static std::string get_reduce_algo(const std::vector<shape>& inputs)
+template <class T>
+static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
+{
+    auto lens = s.lens();
+    std::fill(lens.begin(), lens.end(), 1);
+    for(const auto& axis : axes)
+        lens[axis] = s.lens()[axis];
+    return shape{s.type(), lens};
+}
+
+template <class T>
+static shape get_output_shape(const shape& s, const std::vector<T>& axes)
+{
+    auto lens = s.lens();
+    for(const auto& axis : axes)
+        lens[axis] = 1;
+    return shape{s.type(), lens};
+}
+
+template <class ReduceLens>
+static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
 {
-    auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
     const auto init = std::numeric_limits<std::size_t>::max();
     // The minimum stride
     auto min_stride = std::inner_product(
 
@@ -103,11 +113,27 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
     return "block";
 }
 
+static std::string get_reduce_algo(const std::vector<shape>& inputs)
+{
+    auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
+    return get_reduce_algo(inputs, rlens);
+}
+
-struct reduce_compiler : compiler<reduce_compiler>
+struct simple_reduce_compiler : compiler<simple_reduce_compiler>
 {
     std::vector<std::string> names() const
     {
-        return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
+        return {"simple_reduce",
+                "reduce_sum",
+                "reduce_mean",
+                "reduce_max",
+                "reduce_min",
+                "reduce_prod"};
     }
+
+    static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
+    {
+        return inputs.front().elements() / inputs.back().elements();
+    }
 
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
 
@@ -127,7 +153,7 @@ struct reduce_compiler : compiler<reduce_compiler>
         vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
     auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
     auto block_size = compute_block_size(relements, 256);
-    if(relements > block_size * 256)
+    if(relements >= block_size * 256)
         algo = "block_large";
     options.set_launch_params(
         v, compute_global_for(ctx, nelements * block_size, 256), block_size);
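A worked example of the two shape helpers added above: for an input reduced over axes = {1}, get_reduced_shape keeps only the reduction axes while get_output_shape collapses them.

// Sketch: expected results of the helpers for a 3-d input, axes = {1}.
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
std::vector<std::size_t> axes = {1};
// get_reduced_shape(s, axes) -> lens {1, 3, 1}: non-axis dims set to 1
// get_output_shape(s, axes)  -> lens {2, 1, 4}: axis dims reduced to 1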
@@ -157,44 +183,108 @@ struct reduce_compiler : compiler<reduce_compiler>
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
         value v = value::object{};
-        if(op.name() == "reduce_sum")
-        {
-            v["reduction"] = "op::sum{}";
-        }
-        else if(op.name() == "reduce_mean")
-        {
-            auto reduce_elements = get_reduce_elements(ins->inputs());
-            auto reduce_type = ins->inputs().front()->get_shape().type();
-            v["reduction"] = "op::sum{}";
-            std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
-            // Use float accumulator when reduction size is too large for half
-            if(reduce_type == shape::half_type and reduce_elements > 16384)
-                v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
-            else if(contains({shape::float_type, shape::half_type, shape::double_type},
-                             reduce_type))
-                v["read"] = mean;
-            else
-                v["write"] = mean;
-        }
-        else if(op.name() == "reduce_max")
-        {
-            v["reduction"] = "op::max{}";
-            v["init"] = "lowest{}";
-        }
-        else if(op.name() == "reduce_min")
-        {
-            v["reduction"] = "op::min{}";
-            v["init"] = "highest{}";
-        }
-        else if(op.name() == "reduce_prod")
-        {
-            v["reduction"] = "op::product{}";
-            v["init"] = "1";
-        }
-        else
-        {
-            MIGRAPHX_THROW("Unsupported reduce");
-        }
+        reduce_op r{};
+        r.set(ins, op);
+        v["reduction"] = r.reduction;
+        v["read"] = r.read;
+        v["write"] = r.write;
+        v["init"] = r.init;
+        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
+    }
+};
+
+static const char* const fused_reduce_kernel = R"__migraphx__(
+#include <migraphx/kernels/index.hpp>
+#include <migraphx/kernels/reduce.hpp>
+#include <migraphx/kernels/pointwise.hpp>
+#include <migraphx/kernels/vectorize.hpp>
+#include <args.hpp>
+
+namespace migraphx {
+
+${preamble}
+
+extern "C" {
+MIGRAPHX_GLOBAL void ${kernel}(${params})
+{
+    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
+        fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
+    });
+}
+}
+
+} // namespace migraphx
+
+)__migraphx__";
+
+struct fused_reduce_compiler : compiler<fused_reduce_compiler>
+{
+    std::vector<std::string> names() const { return {"fused_reduce"}; }
+
+    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
+    {
+        auto axes = v.at("axes").to_vector<std::size_t>();
+        auto virtual_inputs = inputs;
+        virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
+        virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
+        virtual_inputs = reduce_dims(virtual_inputs);
+        auto reduce_output_shape = virtual_inputs.back();
+        virtual_inputs.pop_back();
+        auto reduction_shape = virtual_inputs.back();
+        virtual_inputs.pop_back();
+        hip_compile_options options;
+        options.inputs = inputs;
+        options.output = inputs.back();
+        options.virtual_inputs = virtual_inputs;
+        auto faxis = find_fast_axis({options.virtual_inputs.front()});
+        vectorize vec{};
+        auto nelements = reduce_output_shape.elements();
+        auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
+        if(algo == "block")
+        {
+            // Vectorize if the axis is a reduction axis
+            if(reduce_output_shape.lens()[faxis] == 1)
+                vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
+            auto relements = reduction_shape.elements() / vec.size;
+            auto block_size = compute_block_size(relements, 256);
+            if(relements >= block_size * 256)
+                algo = "block_large";
+            options.set_launch_params(
+                v, compute_global_for(ctx, nelements * block_size, 256), block_size);
+        }
+        else if(algo == "lane")
+        {
+            options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
+        }
+        else
+        {
+            MIGRAPHX_THROW("Unknown reduce algo: " + algo);
+        }
+        options.kernel_name = v.get("kernel", "reduce_kernel");
+        auto src = interpolate_string(
+            fused_reduce_kernel,
+            {{"kernel", options.kernel_name},
+             {"params", enum_params(inputs.size(), "void * private_p")},
+             {"args", enum_params(inputs.size(), "private_p")},
+             {"algo", algo},
+             {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
+             {"lambda", v.at("lambda").to<std::string>()},
+             {"transformers", make_transformer_args(vec)},
+             {"preamble", v.get("preamble", std::string{})}});
+        options.params += "-Wno-float-equal";
+        return compile_hip_code_object(src, options);
+    }
+
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
+    {
+        assert(not ins->module_inputs().empty());
+        auto v = op.to_value();
+        auto* rm = ins->module_inputs().front();
+        v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
+        v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
+        v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
         return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
     }
 };
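To make the interpolation concrete, here is a hypothetical instance of fused_reduce_kernel after substitution for two inputs using the block algorithm. The kernel name and the reduced-shape type are illustrative stand-ins for what compile() and generate_make_shape actually emit; none of the substituted values below come from this commit.

// Hypothetical output of interpolate_string (illustrative values only):
extern "C" {
MIGRAPHX_GLOBAL void mul_reduce_sum_kernel(void * private_p0, void * private_p1)
{
    transform_args(make_tensors(), rotate_last())(private_p0, private_p1)(
        [](auto y, auto... xs) {
            fused_reduce<reduce::block, reduced_shape_type>( // type from generate_make_shape
                y, partial(MIGRAPHX_LIFT(fused_reduce_op))(xs...));
        });
}
}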
@@ -204,6 +204,14 @@ constexpr auto compose(Fs... fs)
     })(fs...);
 }
 
+template <class F>
+constexpr auto partial(F f)
+{
+    return [=](auto... xs) {
+        return [=](auto&&... ys) { return f(xs..., static_cast<decltype(ys)>(ys)...); };
+    };
+}
+
 template <class... Ts>
 constexpr auto pack(Ts... xs)
 {
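partial curries a callable: the first call captures leading arguments and the second supplies the rest. This is what lets the fused reduce kernel bind the pointwise inputs xs... up front, so fused_reduce can later invoke the lambda with just the reduction value. A minimal sketch of the behavior (add3 is a made-up example function):

// Sketch: partial(f)(a, b) returns a callable that appends the remaining
// arguments, so partial(f)(a, b)(c) evaluates f(a, b, c).
constexpr auto add3 = [](int a, int b, int c) { return a + b + c; };
static_assert(partial(add3)(1, 2)(3) == 6);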
...@@ -25,17 +25,9 @@ ...@@ -25,17 +25,9 @@
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP #define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC #ifndef MIGRAPHX_USE_HIPRTC
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/math_functions.h> #include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif #endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP #endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
@@ -241,6 +241,12 @@ struct index
     }
 };
 
+#ifdef MIGRAPHX_NLOCAL
+#define MIGRAPHX_GLOBAL \
+    __global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
+#else
+#define MIGRAPHX_GLOBAL __global__
+#endif
+
 inline __device__ __attribute__((const)) index make_index()
 {
     return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm(
 {
     using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
     using reduce_output = reduce::with_axis<Input1, Axis>;
 
     block::template run<reduce_output>([&](auto, auto r) {
         auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
         using value_type = typename Input1::type;
         constexpr auto relements = r.template elements<Input1>();
+        constexpr auto relements_r = vec_type<value_type>{1.0 / relements};
+        auto relements_rsqrt = sqrt(relements_r);
+
         auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
-            return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
+            auto x_out = x * relements_r;
+            // dividing x by sqrt(relements) before squaring allows computing higher values
+            // before overflow in low precision
+            auto x2_sqrt = x * relements_rsqrt;
+            return make_array(x_out, x2_sqrt * x2_sqrt);
         })(input);
 
         auto mean_x = means[0];
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
-// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
 
 // Use float to compute half overload
@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
 // Map math functions to hip half2 functions
 // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
 // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
-// Most but not all of these math ops have operators of the same names. Ones not yet implemented
-// at this time are: exp2, exp10, log2, log10, isinf
+// Most but not all of these math ops have operators of the same names.
 MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
 MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
 MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
 MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
 MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
 MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
-// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
+MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
template <class T, class U> template <class T, class U>
...@@ -189,7 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max) ...@@ -189,7 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
// Add overloads for half that calls the float version // Add overloads for half that calls the float version, this should use "hmax" and "hmin" once
// perf CI docker is upgraded to rocm-5.5
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
@@ -217,14 +217,6 @@ constexpr auto min(const T& a, const U& b)
     return min<common_type_t<T, U>>(a, b);
 }
 
-// Sin for half is broken on hip, so use cos instead
-template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
-constexpr T sin(T x)
-{
-    constexpr const T shift = HIP_PIO2_F;
-    return migraphx::cos(shift - x);
-}
-
 MIGRAPHX_DEVICE_MATH_VEC(abs)
 MIGRAPHX_DEVICE_MATH_VEC(acos)
 MIGRAPHX_DEVICE_MATH_VEC(acosh)