Unverified Commit 18cf0435 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Merge branch 'develop' into blas_tuning

parents 12258d8f 3e8d7196
......@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/env.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "contiguous")
......@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
{
if(not literals[i].empty())
{
if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
{
std::cout << "Constant replace: " << std::endl;
std::vector<instruction_ref> inss;
fix([&](auto self, auto ins) {
if(contains(inss, ins))
return;
for(auto input : ins->inputs())
self(input);
inss.push_back(ins);
})(const_instrs_vec[i]);
m.debug_print(inss);
}
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
......
......@@ -62,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx {
migraphx::value to_value(py::kwargs kwargs);
......@@ -235,7 +236,8 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
py::class_<migraphx::shape> shape_cls(m, "shape");
shape_cls
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
......@@ -261,6 +263,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) {
......
......@@ -481,6 +481,15 @@ shape shape::with_type(type_t t) const
shape shape::to_dynamic() const
{
if(not sub_shapes().empty())
{
std::vector<shape> subs;
std::transform(sub_shapes().cbegin(),
sub_shapes().cend(),
std::back_inserter(subs),
[](auto s) { return s.to_dynamic(); });
return {subs};
}
if(this->dynamic())
{
return *this;
......@@ -488,6 +497,30 @@ shape shape::to_dynamic() const
return {type(), lens(), lens(), {}};
}
// Returns a static version of this shape.
//  - A shape that has sub-shapes is converted by converting every sub-shape
//    with the same `x`.
//  - A shape that is already static is returned unchanged.
//  - Otherwise the lengths start from max_lens(); every dimension whose
//    dynamic_dimension is not fixed is replaced with `x`.
shape shape::to_static(std::size_t x) const
{
    if(not sub_shapes().empty())
    {
        std::vector<shape> subs;
        subs.reserve(sub_shapes().size());
        // Take each sub-shape by const reference to avoid a copy per element.
        std::transform(sub_shapes().cbegin(),
                       sub_shapes().cend(),
                       std::back_inserter(subs),
                       [&](const auto& s) { return s.to_static(x); });
        return {subs};
    }
    if(not this->dynamic())
    {
        return *this;
    }
    auto static_lens = this->max_lens();
    // Fixed dimensions keep their max length; all others become `x`.
    // The dynamic dimension is only inspected, so bind it by const reference.
    std::transform(static_lens.begin(),
                   static_lens.end(),
                   this->dyn_dims().cbegin(),
                   static_lens.begin(),
                   [&](auto sl, const auto& dd) { return dd.is_fixed() ? sl : x; });
    return {type(), static_lens};
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
......
......@@ -52,8 +52,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto conv_const_weights()
{
return match::name("convolution")(match::used_once(),
match::args(match::any(), match::is_constant().bind("w")));
return match::name("convolution")(
match::used_once(),
match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
}
// Builds a `match::name_contains("reduce")` matcher (presumably matching any
// instruction whose operator name contains "reduce" - confirm in matcher.hpp).
auto reduction() { return match::name_contains("reduce"); }
......@@ -203,7 +204,12 @@ struct find_mul_slice_conv
}
};
// a * (x + b) => a * x + a * b
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of a constant, the
// additional add can be constant-folded. Also, better fusions can be applied
// when the add comes after the multiply.
struct find_mul_add
{
auto matcher() const
......@@ -268,6 +274,32 @@ struct find_dot_add
}
};
// conv(x + a, w) => conv(a, w) + conv(x, w)
// Matches a convolution (used once) whose first argument is an add (used
// once) with a constant operand `a`, and whose weights `w` are constant.
struct find_conv_add
{
    auto matcher() const
    {
        auto any_operand   = match::any().bind("x");
        auto const_operand = match::any_of(match::is_constant()).bind("a");
        auto single_use_add =
            match::name("add")(match::either_arg(0, 1)(any_operand, const_operand),
                               match::used_once());
        auto const_weights = match::is_constant().bind("w");
        return match::name("convolution")(match::used_once(),
                                          match::args(single_use_add, const_weights));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto conv_ins       = r.result;
        auto const_input    = r.instructions["a"];
        auto variable_input = r.instructions["x"];
        auto weights        = r.instructions["w"];
        // Apply the same convolution operator to each add operand separately,
        // inserting both new convolutions ahead of the matched one.
        auto conv_of_const =
            m.insert_instruction(conv_ins, conv_ins->get_operator(), const_input, weights);
        auto conv_of_variable =
            m.insert_instruction(conv_ins, conv_ins->get_operator(), variable_input, weights);
        // The matched convolution becomes the sum of the two partial results.
        m.replace_instruction(conv_ins, make_op("add"), conv_of_const, conv_of_variable);
    }
};
struct find_add_lit_broadcast
{
auto matcher() const
......@@ -1239,6 +1271,7 @@ void simplify_algebra::apply(module& m) const
find_neg_unit_ops{},
find_zero_ops{},
find_dot_add{},
find_conv_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
......
......@@ -28,6 +28,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -67,6 +68,37 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max};
}
namespace {
struct find_static_2in_broadcasts
{
    // Rewrites a two-input broadcast/multibroadcast whose inputs both have
    // static shapes into the one-input form. Some compiler passes (e.g.
    // simplify_algebra) only support the one-input broadcasting operators.
    auto matcher() const
    {
        auto static_arg0 = match::arg(0)(match::static_shape());
        auto static_arg1 = match::arg(1)(match::static_shape());
        return match::broadcast(match::nargs(2), static_arg0, static_arg1);
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto bcast_ins   = mr.result;
        auto target_lens = bcast_ins->get_shape().lens();
        auto op          = bcast_ins->get_operator();
        // Bake the already-known output lengths into the operator. Any
        // operator other than "broadcast" additionally gets its
        // out_dyn_dims attribute cleared.
        if(op.name() != "broadcast")
        {
            op.from_value({{"out_lens", target_lens}, {"out_dyn_dims", {}}});
        }
        else
        {
            op.from_value({{"out_lens", target_lens}});
        }
        // Keep only the first input; the second is dropped.
        m.replace_instruction(bcast_ins, op, bcast_ins->inputs().at(0));
    }
};
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
......@@ -97,6 +129,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
......
......@@ -82,7 +82,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
layout_nhwc{},
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{},
......
......@@ -33,7 +33,11 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
if(BUILD_DEV)
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif()
include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
......
......@@ -168,7 +168,7 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return join_strings(std::move(transformers), ", ");
}
std::string generate_pointwise(const module& pm, const std::string& name)
void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{
module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
......@@ -184,8 +184,131 @@ std::string generate_pointwise(const module& pm, const std::string& name)
// Add explicit conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
g.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
gg.create_function(g.generate_module(m)
.set_attributes({"__device__", "__attribute__((const))"})
.set_generic_types(m)
.set_name(name));
}
std::string generate_pointwise(const module& pm, const std::string& name)
{
cpp_generator g;
generate_pointwise(g, pm, name);
return g.str();
}
// Renders the reduction as the expression text
//   <write>(r.reduce(<reduction>, <init>, <read>)(<input>))
std::string reduce_op::str() const
{
    std::string expr = write;
    expr += "(r.reduce(";
    expr += reduction;
    expr += ", ";
    expr += init;
    expr += ", ";
    expr += read;
    expr += ")(";
    expr += input;
    expr += "))";
    return expr;
}
// Fills in `reduction` (and, where needed, `init`, `read` or `write`) for the
// reduce operator `op`. `ins` is consulted only for reduce_mean, to compute
// how many elements are folded into each output element.
// Throws for any operator name that is not handled below.
void reduce_op::set(instruction_ref ins, const operation& op)
{
    const auto op_name = op.name();
    if(op_name == "reduce_sum")
    {
        reduction = "op::sum{}";
        return;
    }
    if(op_name == "reduce_max")
    {
        reduction = "op::max{}";
        init      = "lowest{}";
        return;
    }
    if(op_name == "reduce_min")
    {
        reduction = "op::min{}";
        init      = "highest{}";
        return;
    }
    if(op_name == "reduce_prod")
    {
        reduction = "op::product{}";
        init      = "1";
        return;
    }
    if(op_name == "reduce_mean")
    {
        auto input_shape     = ins->inputs().front()->get_shape();
        auto reduce_elements = input_shape.elements() / ins->get_shape().elements();
        auto reduce_type     = input_shape.type();
        // A mean is a sum with the division applied on read or on write.
        reduction        = "op::sum{}";
        std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
        const bool large_half = reduce_type == shape::half_type and reduce_elements > 16384;
        const bool floating =
            contains({shape::float_type, shape::half_type, shape::double_type}, reduce_type);
        if(large_half)
        {
            // Use float accumulator when reduction size is too large for half
            read = "compose(" + mean + ", op::convert_to<float>{})";
        }
        else if(floating)
        {
            read = mean;
        }
        else
        {
            write = mean;
        }
        return;
    }
    MIGRAPHX_THROW("Unsupported reduce");
}
// Builds the reduction expression text for `ins`, using `x` as the input
// expression: configures a reduce_op from the instruction's operator and
// returns its str().
std::string reduce_op::generate(instruction_ref ins, const std::string& x)
{
    reduce_op config;
    config.input = x;
    config.set(ins, ins->get_operator());
    return config.str();
}
// True when `ins` has exactly one output and that output is either the
// "@return" instruction or an instruction whose name contains "reduce".
static bool use_lazy_inner(instruction_ref ins)
{
    const auto& users = ins->outputs();
    if(users.size() != 1)
        return false;
    auto consumer = users.front();
    if(consumer->name() == "@return")
        return true;
    return contains(consumer->name(), "reduce");
}
// Generates the source text of a device function called `name` for the
// reduce module `m` and returns it (together with any helper pointwise
// functions emitted along the way) as a string.
//
// Each instruction of `m` is translated by the callback below:
//  - names containing "reduce" -> expression from reduce_op::generate
//  - "pointwise"               -> helper function emitted into `g`, then a
//                                 call to it, optionally wrapped in
//                                 r.inner / r.lazy_inner
//  - "multibroadcast"          -> forwards the name of its first input
//  - anything else             -> throws "Unknown operator"
std::string generate_reduce(const module& m, const std::string& name)
{
cpp_generator g;
// Lens of the first entry returned by get_parameter_shapes(); used below to
// decide which pointwise inputs are treated as full tensors.
// NOTE(review): assumes the module has at least one parameter and that the
// first entry is representative - confirm against callers.
auto ilens = m.get_parameter_shapes().begin()->second.lens();
// Counter used to give every emitted pointwise helper a distinct name.
std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
}
else if(ins->name() == "pointwise")
{
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
// Emit the pointwise submodule as its own function into the same generator.
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
// Inputs whose lens equal `ilens` and that are not broadcasted.
std::vector<instruction_ref> tensors;
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(tensors),
[&](auto input) {
return input->get_shape().lens() == ilens and
not input->get_shape().broadcasted();
});
// Inside the generated lambda those inputs are referred to by a suffixed
// parameter name; all other inputs keep their original name.
auto inner_names = names;
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function =
pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
// No matching tensor inputs: call the pointwise function directly.
if(tensors.empty())
return call_function;
const std::string inner_template =
"r.${inner}([=](${params}) { return ${call}; })(${args})";
std::string inner_name = use_lazy_inner(ins) ? "lazy_inner" : "inner";
// `args` are the outer names passed in; `params` are the lambda's
// "auto <name>_lambda_param" parameters.
auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template,
{{"inner", inner_name},
{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
// Extra generic parameter `r`, which the generated expressions above
// reference (r.reduce, r.inner, r.lazy_inner).
f.add_generic_param("r");
g.create_function(f);
return g.str();
}
......@@ -196,7 +319,17 @@ static std::vector<std::string> get_op_names(const module& m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
if(ins.name() == "multibroadcast")
continue;
if(ins.name() == "pointwise")
{
auto names = get_op_names(*ins.module_inputs().front());
result.insert(result.end(), names.begin(), names.end());
}
else
{
result.push_back(ins.name());
}
}
return result;
}
......
......@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
add_executable(gpu-driver
${GPU_DRIVER_SRCS}
)
rocm_clang_tidy_check(gpu-driver)
target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
......@@ -44,7 +44,7 @@ struct auto_register_action
template <class T>
static void apply()
{
auto name = get_type_name<T>();
const auto& name = get_type_name<T>();
register_action(name.substr(name.rfind("::") + 2),
[](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
}
......
......@@ -189,8 +189,20 @@ argument register_on_gpu(const argument& arg)
argument to_gpu(const argument& arg, bool host)
{
auto p = write_to_gpu(arg.data(), arg.get_shape().bytes(), host);
return {arg.get_shape(), p};
argument result;
arg.visit(
[&](auto x) {
auto p = write_to_gpu(arg.data(), arg.get_shape().bytes(), host);
result = {x.get_shape(), p};
},
[&](const auto& xs) {
std::vector<argument> args;
std::transform(xs.begin(), xs.end(), std::back_inserter(args), [&](auto x) {
return to_gpu(x, host);
});
result = argument{args};
});
return result;
}
argument from_gpu(const argument& arg)
......
......@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -34,6 +35,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct shape;
struct operation;
namespace gpu {
......@@ -72,8 +74,23 @@ std::string make_transformer_args(Ts... xs)
std::string generate_pointwise(const module& pm, const std::string& name);
std::string generate_reduce(const module& m, const std::string& name);
std::string generate_name_from_ops(const module& m);
// Pieces of the source text for one reduction. str() assembles them as
//   <write>(r.reduce(<reduction>, <init>, <read>)(<input>))
struct reduce_op
{
// Expression text for the value being reduced.
std::string input = "";
// Text of the binary reduction functor (e.g. "op::sum{}"); filled in by set().
std::string reduction = "";
// Text of the initial value of the reduction.
std::string init = "0";
// Text of the functor passed to r.reduce as the read argument.
std::string read = "op::id{}";
// Text of the functor applied to the reduction result.
std::string write = "op::id{}";
// Fills the fields above according to the name of `op`.
void set(instruction_ref ins, const operation& op);
// Returns the assembled expression text.
std::string str() const;
// Builds a reduce_op with input `x`, applies set() using the operator of
// `ins`, and returns str().
static std::string generate(instruction_ref ins, const std::string& x);
};
} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -71,6 +71,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024);
std::string generate_make_shape(const shape& s);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
......
......@@ -60,15 +60,6 @@ __global__ void reduce_kernel(void* input_p, void* output_p)
)__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
......@@ -86,9 +77,28 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
return reduce_lens;
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
// Returns a shape with the type of `s` in which every dimension listed in
// `axes` keeps its length from `s` and every other dimension has length 1.
template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
    const auto& full_lens = s.lens();
    auto reduced_lens     = s.lens();
    reduced_lens.assign(reduced_lens.size(), 1);
    for(const auto& axis : axes)
        reduced_lens[axis] = full_lens[axis];
    return shape{s.type(), reduced_lens};
}
// Returns a shape with the type and lengths of `s`, except that every
// dimension listed in `axes` has length 1.
template <class T>
static shape get_output_shape(const shape& s, const std::vector<T>& axes)
{
    auto out_lens = s.lens();
    std::for_each(axes.begin(), axes.end(), [&](const auto& axis) { out_lens[axis] = 1; });
    return shape{s.type(), out_lens};
}
template <class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
......@@ -103,11 +113,27 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
return "block";
}
struct reduce_compiler : compiler<reduce_compiler>
// Picks the reduce algorithm when the reduction lengths are not given
// explicitly: they are derived from the lens of the first (input) and last
// (output) shapes, then the two-argument overload is used.
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
    const auto& in_lens  = inputs.front().lens();
    const auto& out_lens = inputs.back().lens();
    return get_reduce_algo(inputs, get_reduce_lens(in_lens, out_lens));
}
struct simple_reduce_compiler : compiler<simple_reduce_compiler>
{
std::vector<std::string> names() const
{
return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
return {"simple_reduce",
"reduce_sum",
"reduce_mean",
"reduce_max",
"reduce_min",
"reduce_prod"};
}
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
......@@ -157,44 +183,108 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type},
reduce_type))
v["read"] = mean;
else
v["write"] = mean;
}
else if(op.name() == "reduce_max")
{
v["reduction"] = "op::max{}";
v["init"] = "lowest{}";
}
else if(op.name() == "reduce_min")
reduce_op r{};
r.set(ins, op);
v["reduction"] = r.reduction;
v["read"] = r.read;
v["write"] = r.write;
v["init"] = r.init;
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
// Source template for the fused-reduce kernel. The ${...} placeholders
// (preamble, kernel, params, transformers, args, algo, reduced, lambda) are
// substituted with interpolate_string in fused_reduce_compiler::compile_op.
// The text is emitted verbatim into the generated kernel source.
static const char* const fused_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
std::vector<std::string> names() const { return {"fused_reduce"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto axes = v.at("axes").to_vector<std::size_t>();
auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs);
auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back();
virtual_inputs.pop_back();
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = virtual_inputs;
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
if(algo == "block")
{
v["reduction"] = "op::min{}";
v["init"] = "highest{}";
// Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size;
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(op.name() == "reduce_prod")
else if(algo == "lane")
{
v["reduction"] = "op::product{}";
v["init"] = "1";
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = v.get("kernel", "reduce_kernel");
auto src = interpolate_string(
fused_reduce_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(not ins->module_inputs().empty());
auto v = op.to_value();
auto* rm = ins->module_inputs().front();
v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
......
......@@ -195,6 +195,14 @@ constexpr auto compose(Fs... fs)
})(fs...);
}
// Partial application: partial(f)(xs...) returns a callable that, when invoked
// with ys..., calls f(xs..., ys...). The bound arguments xs are captured by
// copy; the later arguments ys are forwarded with their value category intact.
template <class F>
constexpr auto partial(F f)
{
    return [=](auto... bound) {
        return [=](auto&&... rest) {
            return f(bound..., static_cast<decltype(rest)&&>(rest)...);
        };
    };
}
template <class... Ts>
constexpr auto pack(Ts... xs)
{
......
......@@ -233,6 +233,12 @@ struct index
}
};
// Kernel entry-point qualifier. When MIGRAPHX_NLOCAL is defined, the kernel is
// also annotated with amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL,
// MIGRAPHX_NLOCAL), i.e. with the same value as both attribute arguments;
// otherwise it is a plain __global__.
#ifdef MIGRAPHX_NLOCAL
#define MIGRAPHX_GLOBAL \
__global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
#else
#define MIGRAPHX_GLOBAL __global__
#endif
inline __device__ __attribute__((const)) index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
......
......@@ -174,6 +174,25 @@ struct inner_storage_tag
template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
// Inner "storage" that holds no data: it wraps a callable `f` and evaluates
// f(j, d) each time it is indexed. Deriving from inner_storage_tag makes it
// satisfy is_inner_storage.
template <class Size, class F>
struct lazy_inner_storage : inner_storage_tag
{
// Element type: the result of calling F with (0, _c<0>), without reference.
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
// Returns a default-constructed Size (the size is carried by the type).
constexpr Size rsize() const { return {}; }
// Element access: forwards both indices to the wrapped callable.
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
// Factory for lazy_inner_storage that deduces Size and F from its arguments.
// The Size argument is unnamed and used only for its type.
template <class Size, class F>
constexpr lazy_inner_storage<Size, F> make_lazy_inner_storage(Size, F f)
{
// Aggregate initialisation: `{}` for the inner_storage_tag base, then `f`.
return {{}, f};
}
template <class R, class F>
struct storage_access : F
{
......@@ -278,6 +297,14 @@ struct reducer_base
});
}
// Like an inner operation, but instead of computing values it returns (via
// inner_sliced) a lazy_inner_storage whose element (j, d) is
// f(xs(j, d)...), evaluated only when the storage is indexed.
// NOTE(review): the exact arguments inner_sliced passes as `n` and `xs` are
// defined outside this view - confirm against inner_sliced.
template <class F>
__device__ auto lazy_inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
});
}
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
......@@ -396,25 +423,6 @@ struct block_large
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
......@@ -439,7 +447,7 @@ struct block_large
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
......@@ -469,25 +477,6 @@ struct lane
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
......@@ -518,7 +507,7 @@ struct lane
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
template <class Slicer>
......@@ -577,5 +566,21 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
});
}
// Entry point for a fused reduction: runs Algo::run<Reduced> and, for each
// invocation of its callback, evaluates `f(r)` and stores the result in
// `output`.
//  - If the result is an inner storage (is_inner_storage), it is copied
//    element by element into `output` through r.inner.
//  - Otherwise the result is converted with implicit_conversion and written
//    to output[out_idx] inside r.outer.
// NOTE(review): how often the callback runs and which threads execute
// r.inner / r.outer depends on Algo, which is defined outside this view.
template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f)
{
Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r);
if constexpr(reduce::is_inner_storage<decltype(result)>{})
{
r.inner([&](auto& y, auto x) { y = x; })(output, result);
}
else
{
r.outer([&] { output[out_idx] = implicit_conversion(result); });
}
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
......@@ -197,10 +197,14 @@ struct mlir_program
result = mlirF64TypeGet(ctx.get());
else if(as.is_integral())
{
if(as.is_signed())
result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
else
result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
// Note: rocMLIR uses signless integer types for tensor types. This
// translates to a signed implementation for the currently supported
// operations.
if(as.is_unsigned())
{
MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
}
result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
}
else
MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
......@@ -483,7 +487,7 @@ struct mlir_program
static value get_operator_value(const operation& op)
{
auto v = op.to_value();
if(op.name() == "convolution")
if(op.name() == "convolution" or op.name() == "quant_convolution")
{
// Adjust symmetrical padding
if(v.at("padding").size() == v.at("stride").size())
......
......@@ -32,6 +32,7 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
......@@ -72,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
{
......@@ -129,6 +131,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment