Unverified Commit 4a4f537e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into gpu-batch-gemm-bert

parents adc03be6 a27dd28c
...@@ -28,30 +28,30 @@ struct hip_stream_model ...@@ -28,30 +28,30 @@ struct hip_stream_model
bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; } bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; }
}; };
stream_model make_stream_model(const module& p) stream_model make_stream_model(const module& m)
{ {
hip_stream_model m; hip_stream_model hsm;
std::size_t stream = 0; std::size_t stream = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "gpu::set_stream") if(ins->name() == "gpu::set_stream")
{ {
auto v = ins->get_operator().to_value(); auto v = ins->get_operator().to_value();
stream = v["stream"].to<std::size_t>(); stream = v["stream"].to<std::size_t>();
m.max_stream = std::max(stream, m.max_stream); hsm.max_stream = std::max(stream, hsm.max_stream);
} }
if(ins->get_operator().is_context_free()) if(ins->get_operator().is_context_free())
continue; continue;
if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name())) if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name()))
continue; continue;
m.ins2stream[ins] = stream; hsm.ins2stream[ins] = stream;
} }
return m; return hsm;
} }
std::vector<stream_race> analyze_streams(const module& p) std::vector<stream_race> analyze_streams(const module& m)
{ {
return migraphx::analyze_streams(p, make_stream_model(p)); return migraphx::analyze_streams(m, make_stream_model(m));
} }
} // namespace gpu } // namespace gpu
......
...@@ -22,6 +22,7 @@ namespace gpu { ...@@ -22,6 +22,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC #if MIGRAPHX_USE_HIPRTC
...@@ -133,6 +134,7 @@ struct hiprtc_program ...@@ -133,6 +134,7 @@ struct hiprtc_program
std::vector<char> buffer(n); std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() == 0); assert(buffer.back() == 0);
// cppcheck-suppress returnDanglingLifetime
return {buffer.begin(), buffer.end() - 1}; return {buffer.begin(), buffer.end() - 1};
} }
...@@ -246,6 +248,16 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -246,6 +248,16 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW("Missing hsaco"); MIGRAPHX_THROW("Missing hsaco");
}; };
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
if(enabled(MIGRAPHX_GPU_DUMP_ASM{})) if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
{ {
......
...@@ -11,11 +11,11 @@ namespace migraphx { ...@@ -11,11 +11,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void eliminate_workspace::apply(module& p) const void eliminate_workspace::apply(module& m) const
{ {
std::size_t n = 0; std::size_t n = 0;
std::vector<instruction_ref> allocs; std::vector<instruction_ref> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->outputs().size() != 1) if(ins->outputs().size() != 1)
continue; continue;
...@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& p) const ...@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& p) const
} }
if(n > 0) if(n > 0)
{ {
auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}}); auto ws = m.add_parameter("workspace", shape{shape::int8_type, {n}});
for(auto&& a : allocs) for(auto&& a : allocs)
{ {
p.replace_instruction(a, ws); m.replace_instruction(a, ws);
p.remove_instruction(a); m.remove_instruction(a);
} }
} }
} }
......
...@@ -316,7 +316,7 @@ struct find_layernorm ...@@ -316,7 +316,7 @@ struct find_layernorm
{ {
auto matcher() const { return match::layernorm(&gpu_name); } auto matcher() const { return match::layernorm(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -331,7 +331,7 @@ struct find_layernorm ...@@ -331,7 +331,7 @@ struct find_layernorm
if(relements > 1024 or (relements % 4 != 0 and relements > 256)) if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return; return;
p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back()); m.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
} }
}; };
...@@ -343,11 +343,11 @@ struct find_triadd_layernorm ...@@ -343,11 +343,11 @@ struct find_triadd_layernorm
match::used_once(), match::all_of[match::inputs()](match::standard_shape())))); match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto triadd = ins->inputs().front(); auto triadd = ins->inputs().front();
p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs()); m.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
} }
}; };
...@@ -355,13 +355,13 @@ struct find_gelu ...@@ -355,13 +355,13 @@ struct find_gelu
{ {
auto matcher() const { return match::gelu_erf(&gpu_name); } auto matcher() const { return match::gelu_erf(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto args = ins->inputs(); auto args = ins->inputs();
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
} }
}; };
...@@ -372,7 +372,7 @@ struct find_add_gelu ...@@ -372,7 +372,7 @@ struct find_add_gelu
return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add"))); return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -381,7 +381,7 @@ struct find_add_gelu ...@@ -381,7 +381,7 @@ struct find_add_gelu
move_broadcasted_back(args); move_broadcasted_back(args);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_gelu{}, args); m.replace_instruction(ins, hip_add_gelu{}, args);
} }
}; };
...@@ -391,16 +391,16 @@ struct find_gelu_new ...@@ -391,16 +391,16 @@ struct find_gelu_new
auto matcher() const { return match::gelu_tanh(&gpu_name); } auto matcher() const { return match::gelu_tanh(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto args = ins->inputs(); auto args = ins->inputs();
if(fast_math) if(fast_math)
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
else else
p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back()); m.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
} }
}; };
...@@ -411,7 +411,7 @@ struct find_add_gelu_new ...@@ -411,7 +411,7 @@ struct find_add_gelu_new
return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add"))); return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -420,7 +420,7 @@ struct find_add_gelu_new ...@@ -420,7 +420,7 @@ struct find_add_gelu_new
move_broadcasted_back(args); move_broadcasted_back(args);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_gelu_new{}, args); m.replace_instruction(ins, hip_add_gelu_new{}, args);
} }
}; };
...@@ -435,7 +435,7 @@ struct find_add_clip ...@@ -435,7 +435,7 @@ struct find_add_clip
.bind("add"))); .bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -448,9 +448,9 @@ struct find_add_clip ...@@ -448,9 +448,9 @@ struct find_add_clip
add_args.pop_back(); add_args.pop_back();
add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end()); add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
if(add_ins->name() == "gpu::add") if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, hip_add_clip{}, add_args); m.replace_instruction(ins, hip_add_clip{}, add_args);
else if(add_ins->name() == "gpu::triadd") else if(add_ins->name() == "gpu::triadd")
p.replace_instruction(ins, hip_triadd_clip{}, add_args); m.replace_instruction(ins, hip_triadd_clip{}, add_args);
} }
}; };
...@@ -470,7 +470,7 @@ struct find_add_unary ...@@ -470,7 +470,7 @@ struct find_add_unary
.bind("add"))); .bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -481,9 +481,9 @@ struct find_add_unary ...@@ -481,9 +481,9 @@ struct find_add_unary
// Use the allocation from the relu operator // Use the allocation from the relu operator
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
if(add_ins->name() == "gpu::add") if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, binary_add_op, args); m.replace_instruction(ins, binary_add_op, args);
else if(add_ins->name() == "gpu::triadd") else if(add_ins->name() == "gpu::triadd")
p.replace_instruction(ins, ternary_add_op, args); m.replace_instruction(ins, ternary_add_op, args);
} }
}; };
...@@ -498,7 +498,7 @@ struct find_triadd ...@@ -498,7 +498,7 @@ struct find_triadd
.bind("input"))); .bind("input")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto input_ins = r.instructions["input"]; auto input_ins = r.instructions["input"];
...@@ -513,7 +513,7 @@ struct find_triadd ...@@ -513,7 +513,7 @@ struct find_triadd
move_broadcasted_back(args); move_broadcasted_back(args);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_triadd{}, args); m.replace_instruction(ins, hip_triadd{}, args);
} }
}; };
...@@ -525,7 +525,7 @@ struct find_mul_add ...@@ -525,7 +525,7 @@ struct find_mul_add
match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b"))); match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto mul_ins = r.instructions["mul"]; auto mul_ins = r.instructions["mul"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
...@@ -538,7 +538,7 @@ struct find_mul_add ...@@ -538,7 +538,7 @@ struct find_mul_add
args.insert(std::prev(args.end()), b_ins); args.insert(std::prev(args.end()), b_ins);
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add{}, args); m.replace_instruction(ins, hip_mul_add{}, args);
} }
}; };
...@@ -550,7 +550,7 @@ struct find_mul_add_relu ...@@ -550,7 +550,7 @@ struct find_mul_add_relu
match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add"))); match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto mul_add_ins = r.instructions["mul_add"]; auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result; auto ins = r.result;
...@@ -558,7 +558,7 @@ struct find_mul_add_relu ...@@ -558,7 +558,7 @@ struct find_mul_add_relu
// Use the allocation from the relu operator // Use the allocation from the relu operator
args.back() = ins->inputs().back(); args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_mul_add_relu{}, args); m.replace_instruction(ins, hip_mul_add_relu{}, args);
} }
}; };
...@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms) ...@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms)
} }
template <class Op> template <class Op>
void apply_conv_bias(context& ctx, module& p, match::matcher_result r) void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
...@@ -798,7 +798,7 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r) ...@@ -798,7 +798,7 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
// TODO: Insert ws allocation // TODO: Insert ws allocation
auto ws = cb.get_workspace(ctx); auto ws = cb.get_workspace(ctx);
(void)ws; (void)ws;
p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
inline auto precompile_name(std::string s) // NOLINT inline auto precompile_name(std::string s) // NOLINT
...@@ -829,9 +829,9 @@ struct find_conv_bias ...@@ -829,9 +829,9 @@ struct find_conv_bias
match::output(match::name(std::unordered_set<std::string>{"gpu::relu"})))); match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r)); apply_conv_bias<miopen_conv_bias>(*ctx, m, r);
} }
}; };
...@@ -840,9 +840,9 @@ struct find_conv_bias_relu ...@@ -840,9 +840,9 @@ struct find_conv_bias_relu
context* ctx = nullptr; context* ctx = nullptr;
auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); } auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r)); apply_conv_bias<miopen_conv_bias_relu>(*ctx, m, r);
} }
}; };
...@@ -857,7 +857,7 @@ struct find_conv_pointwise ...@@ -857,7 +857,7 @@ struct find_conv_pointwise
fusable_conv(match::used_once()).bind("conv"))); fusable_conv(match::used_once()).bind("conv")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
...@@ -896,7 +896,7 @@ struct find_gemm_add ...@@ -896,7 +896,7 @@ struct find_gemm_add
match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto gemm_ins = r.instructions["gemm"]; auto gemm_ins = r.instructions["gemm"];
...@@ -919,15 +919,15 @@ struct find_gemm_add ...@@ -919,15 +919,15 @@ struct find_gemm_add
auto copy_ins = c_ins; auto copy_ins = c_ins;
// Insert copy // Insert copy
if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty()) if(ins == m.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
{ {
copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back()); copy_ins = m.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
} }
inputs.push_back(copy_ins); inputs.push_back(copy_ins);
inputs.push_back(copy_ins); inputs.push_back(copy_ins);
gemm.beta = 1; gemm.beta = 1;
p.replace_instruction(ins, gemm, inputs); m.replace_instruction(ins, gemm, inputs);
} }
}; };
...@@ -938,22 +938,22 @@ struct find_commutative_broadcast ...@@ -938,22 +938,22 @@ struct find_commutative_broadcast
return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape())); return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto args = ins->inputs(); auto args = ins->inputs();
move_broadcasted_back(args); move_broadcasted_back(args);
p.replace_instruction(ins, ins->get_operator(), args); m.replace_instruction(ins, ins->get_operator(), args);
} }
}; };
void fuse_ops::apply(module& p) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(p, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
run_passes(p, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(p, find_triadd{}); match::find_matches(m, find_triadd{});
match::find_matches(p, match::find_matches(m,
find_layernorm{}, find_layernorm{},
find_conv_pointwise{ctx}, find_conv_pointwise{ctx},
find_conv_bias_relu{ctx}, find_conv_bias_relu{ctx},
...@@ -966,8 +966,8 @@ void fuse_ops::apply(module& p) const ...@@ -966,8 +966,8 @@ void fuse_ops::apply(module& p) const
find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}}, find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{}); find_add_clip{});
run_passes(p, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{}); match::find_matches(m, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -11,7 +11,7 @@ struct module; ...@@ -11,7 +11,7 @@ struct module;
namespace gpu { namespace gpu {
std::vector<stream_race> analyze_streams(const module& p); std::vector<stream_race> analyze_streams(const module& m);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -14,7 +14,7 @@ namespace gpu { ...@@ -14,7 +14,7 @@ namespace gpu {
struct eliminate_workspace struct eliminate_workspace
{ {
std::string name() const { return "eliminate_workspace"; } std::string name() const { return "eliminate_workspace"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct fuse_ops ...@@ -16,7 +16,7 @@ struct fuse_ops
context* ctx = nullptr; context* ctx = nullptr;
bool fast_math = true; bool fast_math = true;
std::string name() const { return "gpu::fuse_ops"; } std::string name() const { return "gpu::fuse_ops"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
// GPU-side module transformation object. It exposes a name() accessor and an
// apply(module&) entry point. NOTE(review): apply() is only declared here; what
// it rewrites is defined in the implementation file, which is not visible from
// this header.
struct prefuse_ops
{
    // Receives the module by mutable reference; the body lives out of line.
    void apply(module& m) const;

    // Fixed identifier string for this object.
    std::string name() const
    {
        return "gpu::prefuse_ops";
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
...@@ -17,9 +17,9 @@ struct schedule_model ...@@ -17,9 +17,9 @@ struct schedule_model
{ {
std::size_t streams = 0; std::size_t streams = 0;
std::size_t concurrency() const; std::size_t concurrency() const;
void sched(module& p, instruction_ref ins, std::size_t n) const; void sched(module& m, instruction_ref ins, std::size_t n) const;
void wait(module& p, instruction_ref ins, std::size_t wait_id) const; void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
void record(module& p, instruction_ref ins, std::size_t wait_id) const; void record(module& m, instruction_ref ins, std::size_t wait_id) const;
std::size_t weight(const operation& op) const; std::size_t weight(const operation& op) const;
}; };
......
...@@ -15,7 +15,7 @@ namespace gpu { ...@@ -15,7 +15,7 @@ namespace gpu {
struct sync_device struct sync_device
{ {
std::string name() const { return "sync_device"; } std::string name() const { return "sync_device"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -14,7 +14,7 @@ struct write_literals ...@@ -14,7 +14,7 @@ struct write_literals
context* ctx = nullptr; context* ctx = nullptr;
std::string name() const { return "gpu::write_literals"; } std::string name() const { return "gpu::write_literals"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
...@@ -28,7 +29,8 @@ ${preamble} ...@@ -28,7 +29,8 @@ ${preamble}
extern "C" { extern "C" {
__global__ void kernel(${params}) __global__ void kernel(${params})
{ {
pointwise(${lambda}, ${args}); auto idx = make_index();
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
} }
} }
...@@ -41,40 +43,105 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -41,40 +43,105 @@ struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise"}; } std::vector<std::string> names() const { return {"pointwise"}; }
static std::size_t oversubscribe(const std::vector<shape>& inputs) static std::size_t oversubscribe_if(bool b)
{ {
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); })) if(b)
return 1;
else
return 256; return 256;
else
return 1;
}
static std::size_t find_fast_axis(const std::vector<shape>& inputs)
{
auto permutation = find_permutation(inputs);
auto it = std::max_element(permutation.begin(), permutation.end());
return it - permutation.begin();
}
static std::vector<bool> preload(std::size_t axis, const std::vector<shape>& inputs)
{
const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return result;
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
return result;
}
static std::string preload_str(const std::vector<bool>& bs)
{
std::vector<std::string> bool_strs;
std::transform(bs.begin(), std::prev(bs.end()), std::back_inserter(bool_strs), [](bool b) {
if(b)
return "true";
return "false";
});
return "false, " + join_strings(bool_strs, ", ");
} }
static std::size_t vectorize_elements(const std::vector<shape>& inputs) static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
{ {
std::size_t n = inputs.front().elements(); // If all inputs is half then only use half2
if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) { if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
return s.packed() or s.broadcasted(); return s.type() == shape::half_type;
})) }))
{ return {2};
if((n % 4) == 0) return {4, 2};
return n / 4;
else if((n % 2) == 0)
return n / 2;
} }
return n; static auto vectorize_elements(std::size_t axis, const std::vector<shape>& inputs)
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
return 1;
});
return *std::min_element(max_vec_size.begin(), max_vec_size.end());
} }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, vectorize_elements(inputs), oversubscribe(inputs)));
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec_size = vectorize_elements(axis, options.virtual_inputs);
auto preloads = preload(axis, options.virtual_inputs);
auto is_preloading =
std::accumulate(preloads.begin(), preloads.end(), false, std::logical_or<>{});
options.set_launch_params(v,
compute_global_for(ctx,
options.output.elements() / vec_size,
oversubscribe_if(not is_preloading)));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")}, {{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"vec_size", std::to_string(vec_size)},
{"axis", std::to_string(axis)},
{"preloads", preload_str(preloads)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -52,9 +52,8 @@ struct scatternd_compiler : compiler<scatternd_compiler> ...@@ -52,9 +52,8 @@ struct scatternd_compiler : compiler<scatternd_compiler>
{ {
hip_compile_options options; hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements())); options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
auto out_s = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.output = out_s; options.output = inputs.back();
options.kernel_name = "scatternd_kernel"; options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"}); auto reduction = "assign_" + v.get("reduction", std::string{"none"});
......
...@@ -3,6 +3,14 @@ ...@@ -3,6 +3,14 @@
#include <migraphx/kernels/array.hpp> #include <migraphx/kernels/array.hpp>
// MIGRAPHX_RETURNS(expr): written directly after a parameter list, it supplies
// both the trailing return type and the body from a single expression. It
// expands to `->decltype(expr) { return expr; }`, so the expression is spelled
// once but used twice.
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// MIGRAPHX_LIFT(f): wraps the callable expression `f` in a generic lambda so it
// can be passed around as one object. The lambda takes its arguments as
// `auto&&...` and casts each back to `decltype(xs)` before calling `f`, which
// keeps every argument's value category. The return type is produced by
// MIGRAPHX_RETURNS, i.e. it is the decltype of the call expression itself.
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
namespace migraphx { namespace migraphx {
struct swallow struct swallow
...@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs) ...@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
return [=](auto f) { return f(xs...); }; return [=](auto f) { return f(xs...); };
} }
// join(g, f, fs...): each of f, fs... is a callable that accepts a continuation
// and invokes it with some arguments (the shape produced by pack(xs...)).
// join unpacks them from left to right and finally calls g with all of the
// unpacked arguments concatenated in that order.

// Single-source case: let f hand its arguments straight to g.
template <class G, class F>
constexpr auto join(G g, F f)
{
    return f([=](auto... unpacked) {
        return g(unpacked...);
    });
}

// Multi-source case: unpack f first, then recurse on the remaining sources with
// a continuation that already carries f's arguments in front.
template <class G, class F, class... Fs>
constexpr auto join(G g, F f, Fs... fs)
{
    return f([=](auto... head) {
        return join(
            [=](auto... tail) {
                return g(head..., tail...);
            },
            fs...);
    });
}
template <class Compare, class P1, class P2> template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2) constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{ {
...@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic) ...@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
return arg_c<ic>(); return arg_c<ic>();
} }
inline constexpr auto rotate_last() template <class F>
constexpr auto make_transform(F f)
{ {
return [](auto... xs) { return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
return [=](auto&& f) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
};
};
} }
// An arg transformation takes the arguments and then a function to take the new arguments:
// transform(xs...)([](auto... ys) { ... })
// The transform_args function takes a list of transformations and applies them one after another, in order
template <class F> template <class F>
constexpr auto transform_args(F f) constexpr auto transform_args(F f)
{ {
return [=](auto... xs) { return f;
return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
};
} }
template <class F, class... Fs> template <class F, class... Fs>
constexpr auto transform_args(F f, Fs... fs) constexpr auto transform_args(F f, Fs... fs)
{ {
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); }; return make_transform([=](auto g, auto... xs) {
return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
});
} }
// NOLINTNEXTLINE // identity transform
#define MIGRAPHX_RETURNS(...) \ inline constexpr auto transform_args()
->decltype(__VA_ARGS__) { return __VA_ARGS__; } {
return make_transform([](auto f, auto... xs) { return f(xs...); });
}
// NOLINTNEXTLINE // Rotate the first argument to the last argument
#define MIGRAPHX_LIFT(...) \ inline constexpr auto rotate_last()
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...)) {
return make_transform([](auto f, auto... xs) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
});
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
...@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x) ...@@ -38,20 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{ {
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), idx.global_stride(out.get_shape().elements(),
[&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); }); [&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
});
} }
template <class F, class... Ts> template <class... Transforms>
__device__ void pointwise(F f, Ts*... ps) __device__ auto pointwise(index idx, Transforms... transforms)
{ {
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize()); return [=](auto f, auto*... ps) {
t(ps...)([&](auto... xs) { auto t = transform_args(make_tensors(), rotate_last(), transforms...);
auto idx = make_index(); t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
pointwise_tensor(idx, f, xs...); };
});
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include <migraphx/kernels/index.hpp> #include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx { namespace migraphx {
...@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs) ...@@ -73,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{ {
if constexpr(decltype(tensor_vec_size(x)){} == 0) if constexpr(decltype(tensor_vec_size(x)){} == 0)
{ {
auto v = vectorize(x); auto v = auto_vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset); auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(), idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; }); [&](auto i) { b[i] = v.data()[i]; });
...@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs) ...@@ -126,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
}; };
} }
// Builds an argument transformation that treats the first tensor argument as
// the output: `out` is forwarded untouched, while the remaining arguments are
// routed through preload() (parameterised on the output's element type) and
// the continuation `f` receives `out` followed by whatever preload() yields.
inline __device__ auto auto_preload(index idx)
{
    return make_transform([=](auto f, auto out, auto... xs) {
        using out_type = typename decltype(out)::type;
        auto loader    = preload<out_type>(idx, xs...);
        loader([&](auto... ys) { f(out, ys...); });
    });
}
// Returns a callable that takes a continuation `f` and invokes it with either
// the original tensor `x` (B == false) or a view of `x` rebound onto a
// __shared__ buffer that has been filled from x.data() (B == true).
// This function does not issue a barrier itself; the caller has to
// synchronize before the shared copy is read.
template <bool B, class T>
__device__ auto preload_copy(index idx, T x)
{
    return [=](auto f) {
        if constexpr(not B)
        {
            // Preloading is disabled for this argument: pass it straight through.
            return f(x);
        }
        else
        {
            using value_type         = typename T::type;
            constexpr auto nelements = get_shape_c<T>{}.element_space();
            __shared__ value_type buffer[nelements];
            // TODO: Always vectorize when size > 4, and then use a second loop for remainder
            // Use the widest vector size that evenly divides the element count.
            constexpr auto width =
                find_vectorize_size([&](auto i) { return (nelements % i) == 0; });
            auto src = as_vec<width>(remove_bool(x.data()));
            auto dst = as_vec<width>(remove_bool(buffer));
            idx.local_stride(nelements / width, [&](auto i) { dst[i] = src[i]; });
            return f(x.with(buffer));
        }
    };
}
// Builds an argument transformation that optionally stages each argument in
// shared memory before calling the continuation `f`.
//
// Bs holds one flag per argument: preload_copy<true> copies that argument into
// a __shared__ buffer, preload_copy<false> forwards it unchanged.
// NOTE(review): join() is defined elsewhere; it presumably chains the
// per-argument callables and finally calls `invoke` with their results - confirm.
template <bool... Bs>
__device__ auto auto_preload(index idx)
{
    return make_transform([=](auto f, auto... xs) {
        auto invoke = [=](auto... ys) {
            // The barrier only exists to make the shared-memory copies written by
            // preload_copy visible to every thread. When no argument is preloaded
            // (all flags false, or an empty pack) nothing was written, so skip it.
            if constexpr((Bs or ...))
                __syncthreads();
            f(ys...);
        };
        join(invoke, preload_copy<Bs>(idx, xs)...);
    });
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP #endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
...@@ -118,15 +118,13 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs) ...@@ -118,15 +118,13 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
} }
template <class T, class U, class V, class W, class Settings> template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s) __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, Settings s)
{ {
auto index = make_index(); auto index = make_index();
const auto x = x_t.begin(); const auto x = x_t.begin();
const auto rois = rois_t.begin(); const auto rois = rois_t.begin();
const auto ind = ind_t.begin(); const auto ind = ind_t.begin();
auto out_ptr = y_t.begin();
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
...@@ -176,7 +174,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -176,7 +174,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling) if constexpr(s.is_avg_pooling)
{ {
out_ptr[i] = calc_pooling(offset_x, y_t[i] = calc_pooling(offset_x,
roi_starts, roi_starts,
bin_size, bin_size,
{ph, pw}, {ph, pw},
...@@ -187,7 +185,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -187,7 +185,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
} }
else else
{ {
out_ptr[i] = calc_pooling(offset_x, y_t[i] = calc_pooling(offset_x,
roi_starts, roi_starts,
bin_size, bin_size,
{ph, pw}, {ph, pw},
......
...@@ -60,10 +60,19 @@ constexpr auto common_vec_size() ...@@ -60,10 +60,19 @@ constexpr auto common_vec_size()
})(vec_size<Ts>()...); })(vec_size<Ts>()...);
} }
// bool cannot be used as a vector element type, so pointers to bool are
// reinterpreted as pointers to uint8_t. Every other pointer type is returned
// exactly as it was passed in.
template <class T>
__device__ __host__ T* remove_bool(T* x) { return x; }

inline __device__ __host__ uint8_t* remove_bool(bool* x)
{
    using byte_pointer = uint8_t*;
    return reinterpret_cast<byte_pointer>(x);
}
template <index_int N, class T> template <index_int N, class T>
__device__ __host__ auto as_vec(T* x) __device__ __host__ auto as_vec(T* x)
{ {
if constexpr(N == 0) if constexpr(N < 2)
return x; return x;
else else
return reinterpret_cast<vec<T, N>*>(x); return reinterpret_cast<vec<T, N>*>(x);
......
...@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis) ...@@ -50,19 +50,10 @@ constexpr auto shape_step(Shape s, Axis)
}); });
} }
// bool cannot be used as a vector element type, so pointers to bool are
// reinterpreted as pointers to uint8_t. Every other pointer type is returned
// exactly as it was passed in.
template <class T>
__device__ __host__ T* remove_bool(T* x) { return x; }

inline __device__ __host__ uint8_t* remove_bool(bool* x)
{
    using byte_pointer = uint8_t*;
    return reinterpret_cast<byte_pointer>(x);
}
template <index_int N, class T, class Axis> template <index_int N, class T, class Axis>
__device__ __host__ auto as_vec(T x, Axis axis) __device__ __host__ auto as_vec(T x, Axis axis)
{ {
if constexpr(N == 0) if constexpr(N < 2)
return x; return x;
else else
return make_tensor_view(as_vec<N>(remove_bool(x.data())), return make_tensor_view(as_vec<N>(remove_bool(x.data())),
...@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis) ...@@ -72,7 +63,7 @@ __device__ __host__ auto as_vec(T x, Axis axis)
template <index_int N, class T, class Axis> template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis axis) constexpr auto tensor_step(T x, Axis axis)
{ {
if constexpr(N == 0) if constexpr(N < 2)
{ {
return x; return x;
} }
...@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred) ...@@ -157,11 +148,11 @@ constexpr auto find_vectorize_size(P pred)
else if constexpr(decltype(pred(_c<2>)){}) else if constexpr(decltype(pred(_c<2>)){})
return _c<2>; return _c<2>;
else else
return _c<0>; return _c<1>;
} }
template <class T> template <class T>
__host__ __device__ auto vectorize(T x) __host__ __device__ auto auto_vectorize(T x)
{ {
if constexpr(tensor_vec_size<T>() == 0) if constexpr(tensor_vec_size<T>() == 0)
{ {
...@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs) ...@@ -194,7 +185,7 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
{ {
MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1); MIGRAPHX_ASSERT(s.strides[axis] == 0 or s.strides[axis] == 1);
MIGRAPHX_ASSERT(s.lens[axis] > 0); MIGRAPHX_ASSERT(s.lens[axis] > 0);
MIGRAPHX_ASSERT(n == 0 or s.lens[axis] % n == 0); MIGRAPHX_ASSERT(n == 1 or s.lens[axis] % n == 0);
if constexpr(s.strides[axis] == 0) if constexpr(s.strides[axis] == 0)
return tensor_step<n>(x, axis); return tensor_step<n>(x, axis);
else else
...@@ -215,7 +206,32 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs) ...@@ -215,7 +206,32 @@ inline __device__ __host__ auto auto_vectorize_impl(F f, Ts... xs)
inline __device__ __host__ auto auto_vectorize() inline __device__ __host__ auto auto_vectorize()
{ {
return [](auto... xs) { return [=](auto f) { auto_vectorize_impl(f, xs...); }; }; return make_transform([](auto f, auto... xs) { auto_vectorize_impl(f, xs...); });
}
// Vectorizes tensor `x` by N along the compile-time axis `Axis`.
// When the stride on that axis is non-zero the data is reinterpreted through
// as_vec<N>; when the stride is zero the tensor is handed to tensor_step<N>
// instead.
template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
    constexpr auto s = get_shape_c<T>{};
    if constexpr(s.strides[Axis] != 0)
        return as_vec<N>(x, _c<Axis>);
    else
        return tensor_step<N>(x, _c<Axis>);
}
// Argument transformation with a fixed vector size N and axis Axis.
// For N >= 2 every argument is passed through vectorize_tensor<N, Axis> before
// the continuation `f` is called; for N < 2 the arguments are forwarded as-is.
template <index_int N, index_int Axis>
__device__ __host__ auto vectorize()
{
    return make_transform([](auto f, auto... xs) {
        if constexpr(N >= 2)
            f(vectorize_tensor<N, Axis>(xs)...);
        else
            f(xs...);
    });
}
} // namespace migraphx } // namespace migraphx
......
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
// Matcher/rewriter pair that replaces a matched layernorm pattern with a single
// "gpu::layernorm" instruction writing into a freshly allocated buffer.
struct find_layernorm
{
    // Pattern to look for: the generic layernorm matcher, which binds its input as "x".
    auto matcher() const { return match::layernorm(); }

    // Rewrites the matched instruction r.result in module m.
    // Nothing is changed when the last dimension of the input is larger than
    // 1024, or is larger than 256 and not a multiple of 4.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];

        // Decide whether to fuse before touching the module. The check only reads
        // the lens of the input, which a contiguous copy is expected to leave
        // unchanged, so running it first avoids inserting a "contiguous"
        // instruction that would be left unused when we bail out.
        auto relements = x_ins->get_shape().lens().back();
        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        // The fused op is given a standard-layout input.
        if(not x_ins->get_shape().standard())
            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);

        // Output buffer with the same shape as the (now standard) input.
        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
    }
};
// Matcher/rewriter pair that replaces layernorm(add(add(z1, z2), z3)) with a
// single "gpu::triadd_layernorm" instruction writing into a freshly allocated
// buffer.
struct find_triaddlayernorm
{
    // Pattern: a layernorm whose input "x" is an add with (in either argument
    // order) a non-constant inner add of z1 and z2, and a third operand z3.
    auto matcher() const
    {
        auto add1 =
            match::name("add")(match::none_of(match::is_constant()),
                               match::args(match::any().bind("z1"), match::any().bind("z2")));
        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
        return match::layernorm()(match::var("x")(add2));
    }

    // Rewrites the matched instruction r.result in module m.
    // Nothing is changed when the last dimension of z1 is larger than 1024, or
    // is larger than 256 and not a multiple of 4.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["z1"];
        auto y_ins = r.instructions["z2"];
        auto z_ins = r.instructions["z3"];

        // Decide whether to fuse before touching the module. The check only reads
        // the lens of z1, which a contiguous copy is expected to leave unchanged,
        // so running it first avoids inserting up to three "contiguous"
        // instructions that would be left unused when we bail out.
        auto relements = x_ins->get_shape().lens().back();
        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        // The fused op is given standard-layout inputs.
        for(auto* pins : {&x_ins, &y_ins, &z_ins})
        {
            if(not(*pins)->get_shape().standard())
                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
        }

        // Output buffer with the same shape as the (now standard) first input.
        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
    }
};
} // namespace
// Runs the layernorm fusion rewrites over module m by handing both finders to
// match::find_matches.
// NOTE(review): find_triaddlayernorm is listed before find_layernorm,
// presumably so the more specific add+add+layernorm pattern gets the first
// chance to match - confirm against how find_matches orders its matchers.
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment