Unverified Commit d9a5acbd authored by Paul Fultz II, committed by GitHub

Merge branch 'develop' into jit-vector-reduce

parents d0b7fc9a a27dd28c
@@ -352,7 +352,7 @@ struct cpu_apply
         std::transform(bind_inputs.begin(),
                        bind_inputs.end(),
                        std::back_inserter(inputs),
-                       [&](const auto& s) { return r.instructions.at(s); });
+                       [&](const auto& s) { return r.instructions[s]; });
         inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
         modl->replace_instruction(ins, op, inputs);
     });
...
@@ -159,6 +159,7 @@ add_library(migraphx_gpu
     nonzero.cpp
     pack_args.cpp
     pack_int8_args.cpp
+    prefuse_ops.cpp
     pad.cpp
     pooling.cpp
     quant_convolution.cpp
...
@@ -28,30 +28,30 @@ struct hip_stream_model
     bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; }
 };

-stream_model make_stream_model(const module& p)
+stream_model make_stream_model(const module& m)
 {
-    hip_stream_model m;
+    hip_stream_model hsm;
     std::size_t stream = 0;
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(m))
     {
         if(ins->name() == "gpu::set_stream")
         {
             auto v = ins->get_operator().to_value();
             stream = v["stream"].to<std::size_t>();
-            m.max_stream = std::max(stream, m.max_stream);
+            hsm.max_stream = std::max(stream, hsm.max_stream);
         }
         if(ins->get_operator().is_context_free())
             continue;
         if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name()))
             continue;
-        m.ins2stream[ins] = stream;
+        hsm.ins2stream[ins] = stream;
     }
-    return m;
+    return hsm;
 }

-std::vector<stream_race> analyze_streams(const module& p)
+std::vector<stream_race> analyze_streams(const module& m)
 {
-    return migraphx::analyze_streams(p, make_stream_model(p));
+    return migraphx::analyze_streams(m, make_stream_model(m));
 }
 } // namespace gpu
...
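For context, this test shim maps each instruction to the stream chosen by the most recent gpu::set_stream, skipping context-free operators and allocations/literals that never touch a stream. A minimal usage sketch, assuming the helpers above and the EXPECT macro from this test suite (check_no_races is a hypothetical name):

// Sketch: analyze_streams returns one stream_race per unsynchronized
// cross-stream access, so an empty result means every cross-stream
// dependency is guarded by a record/wait event pair.
void check_no_races(const migraphx::module& m)
{
    auto races = migraphx::gpu::analyze_streams(m);
    EXPECT(races.empty());
}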
@@ -11,11 +11,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-void eliminate_workspace::apply(module& p) const
+void eliminate_workspace::apply(module& m) const
 {
     std::size_t n = 0;
     std::vector<instruction_ref> allocs;
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(m))
     {
         if(ins->outputs().size() != 1)
             continue;
@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& p) const
     }
     if(n > 0)
     {
-        auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}});
+        auto ws = m.add_parameter("workspace", shape{shape::int8_type, {n}});
         for(auto&& a : allocs)
         {
-            p.replace_instruction(a, ws);
-            p.remove_instruction(a);
+            m.replace_instruction(a, ws);
+            m.remove_instruction(a);
         }
     }
 }
...
@@ -316,7 +316,7 @@ struct find_layernorm
 {
     auto matcher() const { return match::layernorm(&gpu_name); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto x_ins = r.instructions["x"];
@@ -331,7 +331,7 @@ struct find_layernorm
         if(relements > 1024 or (relements % 4 != 0 and relements > 256))
             return;

-        p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
+        m.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
     }
 };
@@ -343,11 +343,11 @@ struct find_triadd_layernorm
             match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
     }

-    void apply(module& p, const match::matcher_result& r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto triadd = ins->inputs().front();
-        p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
+        m.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
     }
 };
@@ -355,13 +355,13 @@ struct find_gelu
 {
     auto matcher() const { return match::gelu_erf(&gpu_name); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto x_ins = r.instructions["x"];
         auto args = ins->inputs();

-        p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
+        m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
     }
 };
@@ -372,7 +372,7 @@ struct find_add_gelu
         return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins = r.result;
@@ -381,7 +381,7 @@ struct find_add_gelu
         move_broadcasted_back(args);
         args.back() = ins->inputs().back();

-        p.replace_instruction(ins, hip_add_gelu{}, args);
+        m.replace_instruction(ins, hip_add_gelu{}, args);
     }
 };
@@ -391,16 +391,16 @@ struct find_gelu_new
     auto matcher() const { return match::gelu_tanh(&gpu_name); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto x_ins = r.instructions["x"];
         auto args = ins->inputs();

         if(fast_math)
-            p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
+            m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
         else
-            p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
+            m.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
     }
 };
@@ -411,7 +411,7 @@ struct find_add_gelu_new
         return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins = r.result;
@@ -420,7 +420,7 @@ struct find_add_gelu_new
         move_broadcasted_back(args);
         args.back() = ins->inputs().back();

-        p.replace_instruction(ins, hip_add_gelu_new{}, args);
+        m.replace_instruction(ins, hip_add_gelu_new{}, args);
     }
 };
@@ -435,7 +435,7 @@ struct find_add_clip
                    .bind("add")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins = r.result;
@@ -448,9 +448,9 @@ struct find_add_clip
         add_args.pop_back();
         add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
         if(add_ins->name() == "gpu::add")
-            p.replace_instruction(ins, hip_add_clip{}, add_args);
+            m.replace_instruction(ins, hip_add_clip{}, add_args);
         else if(add_ins->name() == "gpu::triadd")
-            p.replace_instruction(ins, hip_triadd_clip{}, add_args);
+            m.replace_instruction(ins, hip_triadd_clip{}, add_args);
     }
 };
@@ -470,7 +470,7 @@ struct find_add_unary
                    .bind("add")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto ins = r.result;
@@ -481,9 +481,9 @@ struct find_add_unary
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
         if(add_ins->name() == "gpu::add")
-            p.replace_instruction(ins, binary_add_op, args);
+            m.replace_instruction(ins, binary_add_op, args);
         else if(add_ins->name() == "gpu::triadd")
-            p.replace_instruction(ins, ternary_add_op, args);
+            m.replace_instruction(ins, ternary_add_op, args);
     }
 };
@@ -498,7 +498,7 @@ struct find_triadd
                    .bind("input")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto add_ins = r.instructions["add"];
         auto input_ins = r.instructions["input"];
@@ -513,7 +513,7 @@ struct find_triadd
         move_broadcasted_back(args);
         args.back() = ins->inputs().back();

-        p.replace_instruction(ins, hip_triadd{}, args);
+        m.replace_instruction(ins, hip_triadd{}, args);
     }
 };
@@ -525,7 +525,7 @@ struct find_mul_add
             match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto mul_ins = r.instructions["mul"];
         auto b_ins = r.instructions["b"];
@@ -538,7 +538,7 @@ struct find_mul_add
         args.insert(std::prev(args.end()), b_ins);
         args.back() = ins->inputs().back();

-        p.replace_instruction(ins, hip_mul_add{}, args);
+        m.replace_instruction(ins, hip_mul_add{}, args);
     }
 };
@@ -550,7 +550,7 @@ struct find_mul_add_relu
             match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto mul_add_ins = r.instructions["mul_add"];
         auto ins = r.result;
@@ -558,7 +558,7 @@ struct find_mul_add_relu
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();

-        p.replace_instruction(ins, hip_mul_add_relu{}, args);
+        m.replace_instruction(ins, hip_mul_add_relu{}, args);
     }
 };
@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms)
 }

 template <class Op>
-void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
+void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
 {
     auto conv_ins = r.instructions["conv"];
     auto bias_ins = r.instructions["bias"];
@@ -798,7 +798,7 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
     // TODO: Insert ws allocation
     auto ws = cb.get_workspace(ctx);
     (void)ws;
-    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
+    m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }

 inline auto precompile_name(std::string s) // NOLINT
@@ -829,9 +829,9 @@ struct find_conv_bias
             match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
-        apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
+        apply_conv_bias<miopen_conv_bias>(*ctx, m, r);
     }
 };
@@ -840,9 +840,9 @@ struct find_conv_bias_relu
     context* ctx = nullptr;
     auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
-        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
+        apply_conv_bias<miopen_conv_bias_relu>(*ctx, m, r);
     }
 };
@@ -857,7 +857,7 @@ struct find_conv_pointwise
             fusable_conv(match::used_once()).bind("conv")));
     }

-    void apply(module& m, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto conv_ins = r.instructions["conv"];
         auto bias_ins = r.instructions["bias"];
@@ -896,7 +896,7 @@ struct find_gemm_add
             match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
     }

-    void apply(module& p, match::matcher_result r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto gemm_ins = r.instructions["gemm"];
@@ -919,15 +919,15 @@ struct find_gemm_add
         auto copy_ins = c_ins;

         // Insert copy
-        if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
+        if(ins == m.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
         {
-            copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
+            copy_ins = m.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
         }

         inputs.push_back(copy_ins);
         inputs.push_back(copy_ins);
         gemm.beta = 1;
-        p.replace_instruction(ins, gemm, inputs);
+        m.replace_instruction(ins, gemm, inputs);
     }
 };
@@ -938,22 +938,22 @@ struct find_commutative_broadcast
         return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
     }

-    void apply(module& p, const match::matcher_result& r) const
+    void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto args = ins->inputs();
         move_broadcasted_back(args);

-        p.replace_instruction(ins, ins->get_operator(), args);
+        m.replace_instruction(ins, ins->get_operator(), args);
     }
 };

-void fuse_ops::apply(module& p) const
+void fuse_ops::apply(module& m) const
 {
-    match::find_matches(p, find_gelu{}, find_gelu_new{fast_math});
-    run_passes(p, {dead_code_elimination{}});
-    match::find_matches(p, find_triadd{});
-    match::find_matches(p,
+    match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
+    run_passes(m, {dead_code_elimination{}});
+    match::find_matches(m, find_triadd{});
+    match::find_matches(m,
                         find_layernorm{},
                         find_conv_pointwise{ctx},
                         find_conv_bias_relu{ctx},
@@ -966,8 +966,8 @@ void fuse_ops::apply(module& p) const
                         find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
                         find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                         find_add_clip{});
-    run_passes(p, {dead_code_elimination{}});
-    match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
+    run_passes(m, {dead_code_elimination{}});
+    match::find_matches(m, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
 }

 } // namespace gpu
...
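Every fusion in this file follows the same two-member shape: matcher() declares the instruction pattern, and apply() rewrites the module in place, reusing the matched instruction's output allocation. A hedged sketch of that contract, in the style of the structs above (find_example is a hypothetical name; hip_add_relu and the match helpers are the ones this file already uses):

// Illustrative matcher, not part of the commit: fuse gpu::add + gpu::relu.
struct find_example
{
    auto matcher() const
    {
        // A gpu::relu whose first argument is a gpu::add with a single use
        return match::name("gpu::relu")(
            match::arg(0)(match::name("gpu::add")(match::used_once()).bind("add")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;                         // the matched gpu::relu
        auto args = r.instructions["add"]->inputs(); // the add's inputs
        args.back() = ins->inputs().back();          // reuse relu's output buffer
        m.replace_instruction(ins, hip_add_relu{}, args);
    }
};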
@@ -11,7 +11,7 @@ struct module;
 namespace gpu {

-std::vector<stream_race> analyze_streams(const module& p);
+std::vector<stream_race> analyze_streams(const module& m);

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -14,7 +14,7 @@ namespace gpu {
 struct eliminate_workspace
 {
     std::string name() const { return "eliminate_workspace"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -16,7 +16,7 @@ struct fuse_ops
     context* ctx = nullptr;
     bool fast_math = true;
     std::string name() const { return "gpu::fuse_ops"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace gpu
...
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP

#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct prefuse_ops
{
    std::string name() const { return "gpu::prefuse_ops"; }
    void apply(module& m) const;
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
@@ -17,9 +17,9 @@ struct schedule_model
 {
     std::size_t streams = 0;
     std::size_t concurrency() const;
-    void sched(module& p, instruction_ref ins, std::size_t n) const;
-    void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
-    void record(module& p, instruction_ref ins, std::size_t wait_id) const;
+    void sched(module& m, instruction_ref ins, std::size_t n) const;
+    void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
+    void record(module& m, instruction_ref ins, std::size_t wait_id) const;
     std::size_t weight(const operation& op) const;
 };
...
@@ -15,7 +15,7 @@ namespace gpu {
 struct sync_device
 {
     std::string name() const { return "sync_device"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -14,7 +14,7 @@ struct write_literals
     context* ctx = nullptr;
     std::string name() const { return "gpu::write_literals"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace gpu
...
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

namespace {

struct find_layernorm
{
    auto matcher() const { return match::layernorm(); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto x_ins = r.instructions["x"];
        if(not x_ins->get_shape().standard())
            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);

        auto relements = x_ins->get_shape().lens().back();
        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
    }
};

struct find_triaddlayernorm
{
    auto matcher() const
    {
        auto add1 =
            match::name("add")(match::none_of(match::is_constant()),
                               match::args(match::any().bind("z1"), match::any().bind("z2")));
        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
        return match::layernorm()(match::var("x")(add2));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto x_ins = r.instructions["z1"];
        auto y_ins = r.instructions["z2"];
        auto z_ins = r.instructions["z3"];

        for(auto* pins : {&x_ins, &y_ins, &z_ins})
        {
            if(not(*pins)->get_shape().standard())
                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
        }

        auto relements = x_ins->get_shape().lens().back();
        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

        auto a = m.insert_instruction(
            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
    }
};

} // namespace

void prefuse_ops::apply(module& m) const
{
    match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
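A minimal sketch of driving this new pass on its own, assuming MIGraphX's run_passes from pass_manager.hpp (the prefuse helper name is hypothetical). As in fuse_ops above, replace_instruction leaves the matched subgraph dead, so a dead_code_elimination pass runs right after, matching the pipeline order added in target.cpp below:

#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>

// Sketch: rewrite matched layernorm subgraphs to gpu::layernorm /
// gpu::triadd_layernorm, then clean up the bypassed reduce/sub/pow chain.
void prefuse(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::gpu::prefuse_ops{}, migraphx::dead_code_elimination{}});
}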
@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
 MIGRAPHX_REGISTER_OP(set_stream)

 std::size_t schedule_model::concurrency() const { return streams; }

-void schedule_model::sched(module& p, instruction_ref ins, std::size_t n) const
+void schedule_model::sched(module& m, instruction_ref ins, std::size_t n) const
 {
     auto last_stream = std::find_if(std::make_reverse_iterator(ins),
-                                    std::make_reverse_iterator(p.begin()),
+                                    std::make_reverse_iterator(m.begin()),
                                     [&](auto&& i) { return i.name() == "gpu::set_stream"; });
-    if(last_stream != std::make_reverse_iterator(p.begin()))
+    if(last_stream != std::make_reverse_iterator(m.begin()))
     {
         auto&& op = any_cast<set_stream>(last_stream->get_operator());
         // If the same stream was set earlier then skip
         if(op.stream == n)
             return;
     }
-    p.insert_instruction(ins, set_stream{n});
+    m.insert_instruction(ins, set_stream{n});
 }

-void schedule_model::wait(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::wait(module& m, instruction_ref ins, std::size_t wait_id) const
 {
-    p.insert_instruction(ins, wait_event{wait_id});
+    m.insert_instruction(ins, wait_event{wait_id});
 }

-void schedule_model::record(module& p, instruction_ref ins, std::size_t wait_id) const
+void schedule_model::record(module& m, instruction_ref ins, std::size_t wait_id) const
 {
-    p.insert_instruction(std::next(ins), record_event{wait_id});
+    m.insert_instruction(std::next(ins), record_event{wait_id});
 }

 static std::unordered_map<std::string, std::size_t> create_weight_map()
...
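How the three hooks compose: the generic scheduler picks a stream for each instruction via sched(), and guards every cross-stream edge with a record()/wait() pair. An illustration only, not code from this commit (producer, consumer, and event_id are assumed to refer to instructions already in m and a fresh event number):

// Sketch of the protocol the scheduler follows with this model.
migraphx::gpu::schedule_model model{/*streams=*/4};
model.sched(m, producer, 0);         // gpu::set_stream 0 before the producer
model.record(m, producer, event_id); // gpu::record_event right after it
model.sched(m, consumer, 1);         // gpu::set_stream 1 before the consumer
model.wait(m, consumer, event_id);   // gpu::wait_event before it runs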
@@ -8,9 +8,9 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-void sync_device::apply(module& p) const
+void sync_device::apply(module& m) const
 {
-    auto last = std::prev(p.end());
+    auto last = std::prev(m.end());
     if(last->name() == "@return")
     {
         auto inputs = last->inputs();
@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const
                return (i->name() == "hip::copy_from_gpu");
            }))
         {
-            auto sync_in = p.insert_instruction(last, make_op("hip::sync_stream"), inputs);
+            auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs);
             if(not inputs.empty())
             {
-                p.replace_instruction(inputs.front(), sync_in);
+                m.replace_instruction(inputs.front(), sync_in);
             }
         }
     }
...
@@ -31,6 +31,7 @@
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/eliminate_workspace.hpp>
 #include <migraphx/gpu/fuse_ops.hpp>
+#include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/gpu/lowering.hpp>
 #include <migraphx/gpu/mlir_conv.hpp>
 #include <migraphx/gpu/pack_int8_args.hpp>
@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_algebra{},
         simplify_reshapes{},
         simplify_algebra{},
+        prefuse_ops{},
+        dead_code_elimination{},
         auto_contiguous{},
         simplify_reshapes{},
         propagate_constant{},
...
@@ -11,25 +11,25 @@ namespace gpu {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)

-void write_literals::apply(module& p) const
+void write_literals::apply(module& m) const
 {
     assert(ctx != nullptr);
     std::size_t n = 0;
-    for(auto ins : iterator_for(p))
+    for(auto ins : iterator_for(m))
     {
         if(ins->name() == "@literal")
         {
             if(enabled(MIGRAPHX_COPY_LITERALS{}))
             {
                 literal l = ins->get_literal();
-                auto pre = p.add_literal(l);
-                auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
-                p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
+                auto pre = m.add_literal(l);
+                auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
+                m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
             }
             else
             {
-                std::string id = p.name() + ":@literal:" + std::to_string(n);
-                p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
+                std::string id = m.name() + ":@literal:" + std::to_string(n);
+                m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
                 n++;
             }
         }
...
@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
     EXPECT(result == migraphx::literal{0});
 }

+TEST_CASE(reused_twice)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    std::vector<size_t> dims = {1, 2, 2};
+    auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
+    auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
+    auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
+    auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
+    auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
+    auto epsilon = mm->add_literal(1e-12f);
+    auto exponent = mm->add_literal(2.0f);
+
+    auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
+    auto mean_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
+    auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
+    auto exponent_mbcast =
+        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
+    auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
+    auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
+    auto epsilon_mbcast = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
+    auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
+    mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
+    mm->add_instruction(migraphx::make_op("add"), x, y);
+
+    auto count = std::distance(mm->begin(), mm->end());
+    run_pass(p);
+    p.debug_print();
+    EXPECT(std::distance(mm->begin(), mm->end()) != count);
+    EXPECT(std::distance(mm->begin(), mm->end()) == 4);
+}
+
 TEST_CASE(unused_module)
 {
     migraphx::program p;
...
@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
         match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
     auto r = find_match(mm, m);
     EXPECT(bool{r.result == sum1});
-    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+    EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
 }

 TEST_CASE(match_either_args_any2)
@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
         match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
     auto r = find_match(mm, m);
     EXPECT(bool{r.result == sum1});
-    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+    EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
 }

 TEST_CASE(match_either_args_any3)
@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
         match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
     auto r = find_match(mm, m);
     EXPECT(bool{r.result == sum1});
-    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+    EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
 }

 TEST_CASE(match_either_args_any4)
@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
         match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
     auto r = find_match(mm, m);
     EXPECT(bool{r.result == sum2});
-    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+    EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
 }

 TEST_CASE(match_either_args_any5)
@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
         match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
     auto r = find_match(mm, m);
     EXPECT(bool{r.result == sum2});
-    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+    EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
 }

 TEST_CASE(match_all_of1)
@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
                      match::standard_shape())
                  .bind("pass");
     auto r = find_match(mm, m);
-    EXPECT(bool{r.instructions.at("one") == one});
-    EXPECT(bool{r.instructions.at("two") == two});
-    EXPECT(bool{r.instructions.at("sum") == sum});
-    EXPECT(bool{r.instructions.at("pass") == pass});
+    EXPECT(bool{r.instructions["one"] == one});
+    EXPECT(bool{r.instructions["two"] == two});
+    EXPECT(bool{r.instructions["sum"] == sum});
+    EXPECT(bool{r.instructions["pass"] == pass});
     EXPECT(bool{r.result == pass});
 }
@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
                      match::standard_shape())
                  .bind("pass");
     auto r = find_match(*child, m);
-    EXPECT(bool{r.instructions.at("two") == two});
-    EXPECT(bool{r.instructions.at("sum") == sum});
-    EXPECT(bool{r.instructions.at("pass") == pass});
+    EXPECT(bool{r.instructions["two"] == two});
+    EXPECT(bool{r.instructions["sum"] == sum});
+    EXPECT(bool{r.instructions["pass"] == pass});
     EXPECT(bool{r.result == pass});
 }
...
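These cases all exercise the same property: either_arg(0, 1) tries both argument orders, so a matcher does not care which operand holds the literal. A condensed sketch of the pattern under test, using only names that appear in the tests above:

// Matches sum(lit, x) as well as sum(x, lit); the binds then expose
// whichever instruction landed in each slot via r.instructions["..."].
auto matcher = match::name("sum")(
    match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));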
@@ -4,12 +4,20 @@

 set -e

-#install pip3, rocm-cmake, rocblas and miopen
-apt update && apt install -y python3-pip rocm-cmake rocblas miopen-hip openmp-extras
+export LC_ALL=C.UTF-8
+export LANG=C.UTF-8
+
+# Need pip3 and Python headers to build dependencies
+apt update && apt install -y python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
+
+# Needed for cmake to build various pip packages
+pip3 install setuptools wheel

 # install rbuild to build dependencies
 pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz

 PREFIX=/usr/local
 REQ_FILE_DIR=""
 if [ "$#" -ge 2 ]; then
...