Commit f1c8e6c9 authored by turneram

Merge remote-tracking branch 'origin/develop' into ck-integration-tuning

parents d09b7682 c1b8c975
@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
     if(ins == std::prev(this->end()))
     {
+        // "rep" could be used earlier in the program, and moving it to the end
+        // may produce an invalid program, so insert an identity operation in this case.
         return replace_instruction(ins, make_op("identity"), rep);
     }
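The identity fallback above matters whenever rep already has uses earlier in the module: moving rep into the last position would reorder it past its existing readers. A minimal sketch of the situation, using only the public module API and hypothetical instruction names:

    #include <migraphx/module.hpp>
    #include <migraphx/instruction.hpp>
    #include <migraphx/make_op.hpp>

    // Hypothetical IR: "rep" (abs) already feeds the last instruction (relu).
    // Replacing the last instruction with rep cannot move rep to the end
    // without breaking that earlier use, so it becomes identity(rep) instead.
    void sketch()
    {
        migraphx::module m;
        auto x    = m.add_parameter("x", {migraphx::shape::float_type, {4}});
        auto rep  = m.add_instruction(migraphx::make_op("abs"), x);
        auto last = m.add_instruction(migraphx::make_op("relu"), rep);
        m.replace_instruction(last, rep); // last becomes identity(rep)
    }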
@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
     return end();
 }

-void module::finalize(context& ctx)
+void module::finalize(std::vector<context>& contexts)
 {
+    assert(not contexts.empty());
     const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
     for(auto ins : iterator_for(*this))
     {
@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
             std::cout << "Finalize: ";
             this->debug_print(ins);
         }
-        ins->finalize(ctx);
+        ins->finalize(contexts[ins->get_target_id()]);
         for(const auto& smod : ins->module_inputs())
         {
-            smod->finalize(ctx);
+            smod->finalize(contexts);
         }
     }
...
@@ -38,6 +38,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER)

 static shape shape_from_dyn_dims(shape::type_t shape_type,
                                  const std::vector<shape::dynamic_dimension>& dyn_dims)
@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type,
     return {shape_type, dyn_dims};
 }

-namespace onnx {
-
 static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
 {
     std::unordered_map<std::string, onnx::AttributeProto> result;
@@ -297,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
     return version;
 }

-std::vector<instruction_ref>
-onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
+void print_added_instructions(module* mod,
+                              const std::vector<instruction_ref>& args,
+                              const std::vector<instruction_ref>& result)
+{
+    // Print the instructions added by the parser that are not in args
+    std::vector<instruction_ref> added_instructions;
+    fix([&](auto self, auto r) {
+        for(auto ins : r)
+        {
+            if(contains(args, ins))
+                continue;
+            if(contains(added_instructions, ins))
+                continue;
+            self(ins->inputs());
+            added_instructions.push_back(ins);
+        }
+    })(result);
+    mod->debug_print(added_instructions);
+}
+std::unordered_map<std::string, instruction_ref>
+parse_initializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
 {
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            std::cout << "initializer: " << f.name() << std::endl;
         // back up instructions in the parent mod
-        mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
+        mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            mod->debug_print(mod_insts[f.name()]);
     }
+    return mod_insts;
+}
+std::unordered_map<std::string, instruction_ref>
+parse_inputs(const onnx_parser& parser,
+             module* mod,
+             const onnx::GraphProto& graph,
+             std::unordered_map<std::string, instruction_ref> mod_insts)
+{
     for(auto&& input : graph.input())
     {
         const std::string& name = input.name();
@@ -317,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
            // scenario that a nested subgraph contains a parameter with a
            // name that already exists in its parent graph.
            // In the current implementation, MIGraphX throws an exception for that.
-           if(contains(instructions, name))
+           if(contains(parser.instructions, name))
            {
                MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
                               "\" existing in parent graph!");
@@ -325,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
            shape s;
            std::vector<std::size_t> dims;
-           if(map_input_dims.count(name) > 0)
+           if(parser.map_input_dims.count(name) > 0)
            {
-               dims = map_input_dims.at(name);
-               s    = parse_type(input.type(), dims);
+               dims = parser.map_input_dims.at(name);
+               s    = parser.parse_type(input.type(), dims);
            }
-           else if(map_dyn_input_dims.count(name) > 0)
+           else if(parser.map_dyn_input_dims.count(name) > 0)
            {
                shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
-               s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
+               s = shape_from_dyn_dims(shape_type, parser.map_dyn_input_dims.at(name));
            }
            else
            {
-               s = parse_type(input.type(), dims);
+               s = parser.parse_type(input.type(), dims);
            }
            mod_insts[name] = mod->add_parameter(name, s);
        }
    }
+    return mod_insts;
+}
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
+{
+    std::unordered_map<std::string, instruction_ref> mod_insts =
+        parse_initializer(*this, mod, graph);
+    mod_insts = parse_inputs(*this, mod, graph, mod_insts);
     std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));

     for(auto&& node : graph.node())
     {
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+            std::cout << "operator: " << node.op_type() << std::endl;
         std::vector<instruction_ref> args;
         for(auto&& input : node.input())
         {
@@ -384,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
                        result.begin(),
                        std::inserter(instructions, instructions.end()),
                        [](auto&& x, auto&& y) { return std::make_pair(x, y); });
+        if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
+        {
+            print_added_instructions(mod, args, result);
+        }
     }

     // Find instructions corresponding to the output
...
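The new print_added_instructions helper relies on the fix combinator (MIGraphX keeps its own in migraphx/functional.hpp) so that an unnamed lambda can recurse over ins->inputs(). A self-contained sketch of the idiom, not MIGraphX's own definition:

    #include <iostream>
    #include <utility>

    // Minimal fixed-point combinator: the wrapped callable receives itself
    // as its first argument, so an anonymous lambda can recurse.
    template <class F>
    auto fix(F f)
    {
        return [=](auto&&... xs) { return f(fix(f), std::forward<decltype(xs)>(xs)...); };
    }

    int main()
    {
        // Same shape as the parser's depth-first walk, shown on a toy problem:
        auto factorial = fix([](auto self, int n) -> int { return n <= 1 ? 1 : n * self(n - 1); });
        std::cout << factorial(5) << "\n"; // prints 120
    }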
@@ -31,8 +31,6 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
-
 struct parse_where : op_parser<parse_where>
 {
     std::vector<op_desc> operators() const { return {{"Where"}}; }
@@ -59,13 +57,6 @@ struct parse_where : op_parser<parse_where>
             compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
         lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());

-        if(enabled(MIGRAPHX_ENABLE_CK{}))
-        {
-            // Convert condition tensor to int32 to work around CK not supporting bool type
-            args[0] = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), args[0]);
-        }
-
         if(args[0]->get_shape().lens() != lens)
         {
             args[0] =
...
@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
 struct module_pm : module_pass_manager
 {
     module* mod      = nullptr;
+    module* root_mod = nullptr;
     tracer* t        = nullptr;
     module* common_parent = nullptr;
     program* prog    = nullptr;

-    module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
+    module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
+        : mod(pmod), root_mod(rmod), t(pt)
+    {
+    }

     template <class... Ts>
     void trace(Ts&&... xs) const
     {
@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
     virtual module* get_root_module() override
     {
+        if(root_mod != nullptr)
+            return root_mod;
         assert(prog);
         return prog->get_main_module();
     }
@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& passes, tracer trace)
             continue;
         if(not visited.insert(mod).second)
             continue;
-        module_pm mpm{mod, &trace};
+        module_pm mpm{mod, root_mod, &trace};
         mpm.prog = &prog;
         auto parents  = range(tree.equal_range(mod));
         auto nparents = distance(parents);
@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
         trace = tracer{std::cout};
     for(const auto& p : passes)
     {
-        module_pm{&mod, &trace}.run_pass(p);
+        module_pm{&mod, &mod, &trace}.run_pass(p);
     }
 }
...
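Because the module-level run_passes now constructs module_pm{&mod, &mod, &trace}, a pass that consults get_root_module() sees a consistent answer whether it runs over a whole program or over a standalone module. A hedged usage sketch (the pass list is illustrative):

    #include <migraphx/pass_manager.hpp>
    #include <migraphx/dead_code_elimination.hpp>
    #include <migraphx/module.hpp>

    void cleanup(migraphx::module& m)
    {
        // m serves as both the module being rewritten and the root module
        // reported to any pass that asks for it.
        migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
    }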
@@ -70,9 +70,8 @@ struct program_impl
 {
     // A map is used to keep references to modules of the program
     std::unordered_map<std::string, module> modules;
-    context ctx;
-    std::string target_name;
     std::vector<context> contexts;
+    std::vector<target> targets;
 };

 program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
@@ -96,14 +95,8 @@ void program::assign(const program& p)
     {
         impl = std::make_unique<program_impl>();
     }
-    else if(not impl->modules.empty())
-    {
-        impl->modules.clear();
-    }

-    impl->ctx         = p.impl->ctx;
-    impl->target_name = p.impl->target_name;
-    impl->modules     = p.impl->modules;
+    *impl = *p.impl;

     // build a map from old ins to new ins
     // Build a map from old module to new module
@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
     return mm->get_output_shapes();
 }

-context& program::get_context() const { return impl->ctx; }
+context& program::get_context() const
+{
+    assert(impl->contexts.size() == 1);
+    return impl->contexts.front();
+}

 instruction_ref program::validate() const
 {
@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& targets)
     return p;
 }

-bool program::is_compiled() const { return not this->impl->target_name.empty(); }
+bool program::is_compiled() const { return not this->impl->contexts.empty(); }

 void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
 {
@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_opts)
             MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
                            " from instruction " + std::to_string(index));
         }
-        current_mod->finalize(this->impl->contexts[root_target_id]);
     }
+    this->finalize();
 }

 void program::compile(const target& t, compile_options options)
 {
     // todo: combine with multi-target compile method
     assert(not this->is_compiled());
-    this->impl->target_name = t.name();
-    this->impl->ctx         = t.get_context();
+    this->impl->targets  = {t};
+    this->impl->contexts = {t.get_context()};

     if(enabled(MIGRAPHX_TRACE_COMPILE{}))
         options.trace = tracer{std::cout};

     options.trace(*this);
     options.trace();

-    auto&& passes = t.get_passes(this->impl->ctx, options);
+    auto&& passes = t.get_passes(this->impl->contexts.front(), options);
     run_passes(*this, passes, options.trace);
     auto mods = this->get_modules();

     // Validate and finalize
@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
             MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
                            std::to_string(index));
         }
-        mod->finalize(this->impl->ctx);
+        mod->finalize(this->impl->contexts);
     }
 }

 void program::finalize()
 {
     auto* mm = this->get_main_module();
-    mm->finalize(this->impl->ctx);
+    mm->finalize(this->impl->contexts);
 }

 template <class T>
@@ -359,6 +356,31 @@ std::string classify(T x)
     }
 }

+void print_statistics(std::ostream& os, const argument& a)
+{
+    a.visit(
+        [&](auto t) {
+            os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
+            os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
+            double num_elements = t.size();
+            auto mean           = std::accumulate(t.begin(), t.end(), 0.0) / num_elements;
+            auto stddev         = std::sqrt(
+                std::accumulate(t.begin(),
+                                t.end(),
+                                0.0,
+                                [&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
+                num_elements);
+            os << "Mean: " << mean << ", ";
+            os << "StdDev: " << stddev << "\n";
+        },
+        [&](const auto& xs) {
+            for(const auto& x : xs)
+            {
+                print_statistics(os, x);
+            }
+        });
+}
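print_statistics computes the two-pass population form of the standard deviation (divide by N, not N - 1) over whatever element type the argument holds. A hedged usage sketch, assuming the function is declared in a visible header:

    #include <migraphx/argument.hpp>
    #include <migraphx/shape.hpp>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
        migraphx::shape s{migraphx::shape::float_type, {4}};
        migraphx::argument a{s, data.data()};
        // Expected: Min value: 1, Max value: 4, Mean: 2.5, StdDev: 1.11803
        migraphx::print_statistics(std::cout, a);
    }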
 std::unordered_set<std::string> classify_argument(const argument& a)
 {
     std::unordered_set<std::string> result;
@@ -404,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)

 template <class F>
 std::vector<argument> generic_eval(const module* mod,
-                                   context& ctx,
+                                   std::vector<context>& ctx,
                                    std::unordered_map<std::string, argument> params,
                                    std::unordered_map<instruction_ref, argument> results,
-                                   F make_trace)
+                                   F trace)
 {
     assert(mod->validate() == mod->end());
     results.reserve(mod->size() * 2);
     std::vector<argument> values;
     values.reserve(16);
-    auto trace = make_trace(mod);
     for(auto ins : iterator_for(*mod))
     {
         assert(results.find(ins) == results.end());
@@ -469,14 +490,19 @@ std::vector<argument> generic_eval(const module* mod,
             const auto& mod_args = ins->module_inputs();
             auto module_eval     = [&](module_ref smod,
                                        const std::unordered_map<std::string, argument>& inputs) {
-                auto ssctx = ctx;
-                return generic_eval(smod, ssctx, inputs, results, make_trace);
+                return generic_eval(smod, ctx, inputs, results, trace);
             };

-            results.emplace(ins, trace(ins, [&] {
-                return ins->normalized_operator().compute(
-                    ctx, ins->get_shape(), values, mod_args, module_eval);
-            }));
+            results.emplace(
+                ins, trace(ins, [&] {
+                    auto op = ins->normalized_operator();
+                    if(op.is_context_free())
+                        return op.compute(ins->get_shape(), values, mod_args, module_eval);
+                    if(ins->get_target_id() >= ctx.size())
+                        MIGRAPHX_THROW("No context available for " + op.name());
+                    return op.compute(
+                        ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
+                }));
         }
         assert(results.find(ins) != results.end());
         if(not ins->get_shape().any_of_dynamic())
@@ -489,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,

 template <class F>
 std::vector<argument> generic_eval(const program& p,
-                                   context& ctx,
+                                   std::vector<context>& ctx,
                                    std::unordered_map<std::string, argument> params,
-                                   F make_trace)
+                                   F trace)
 {
     const module* mm = p.get_main_module();
-    return generic_eval(mm, ctx, params, {}, make_trace);
+    return generic_eval(mm, ctx, params, {}, trace);
 }

 std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
 {
-    auto& ctx = this->impl->ctx;
-#ifndef NDEBUG
-    auto with_check_context = [&](auto f) {
-        return [=, &ctx](auto&&) {
-            auto sctx          = std::make_shared<context>(ctx);
-            auto check_context = [=, &ctx](auto g) {
-                assert(is_shared(ctx, *sctx));
-                auto x = g();
-                *sctx  = ctx;
-                return x;
-            };
-            return [=](auto&&... xs) { return f(xs..., check_context); };
-        };
-    };
-#else
-    auto with_check_context = [](auto f) {
-        return [=](auto&&) {
-            return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
-        };
-    };
-#endif
+    auto& contexts = this->impl->contexts;

     auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
     std::vector<argument> ret;

     if(exec_env.async)
     {
-        ctx.wait_for(exec_env.queue);
+        assert(contexts.size() == 1);
+        contexts.front().wait_for(exec_env.queue);
     }

     if(trace_level > 0)
@@ -538,82 +545,79 @@ std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
             instruction::print(ss, x, ins_names);
             ins_out[x] = ss.str();
         });

-        ret = generic_eval(
-            *this,
-            ctx,
-            std::move(params),
-            with_check_context([&](auto& ins, auto f, auto&& check_context) {
-                ctx.finish();
-                std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
-                timer t{};
-                auto result = check_context(f);
-                double t1   = t.record<milliseconds>();
-                ctx.finish();
-                double t2 = t.record<milliseconds>();
-                std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
-                if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
-                   not result.empty())
-                {
-                    migraphx::argument buffer;
-                    try
-                    {
-                        target tgt = make_target(this->impl->target_name);
-                        buffer     = tgt.copy_from(result);
-                    }
-                    catch(const migraphx::exception&)
-                    {
-                        // instruction was run on host then no need to copy buffer from target
-                        buffer = result;
-                    }
-                    catch(...)
-                    {
-                        MIGRAPHX_THROW(
-                            "MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
-                    }
-                    if(trace_level == 2)
-                    {
-                        std::cout << "Output has " << to_string_range(classify_argument(buffer))
-                                  << std::endl;
-                        std::cout << "Output: ";
-                        preview_argument(std::cout, buffer);
-                        std::cout << std::endl;
-                    }
-                    else
-                    {
-                        std::cout << "Output: " << buffer << std::endl;
-                    }
-                }
-                return result;
-            }));
+        ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
+            auto& ctx = contexts[ins->get_target_id()];
+            ctx.finish();
+            std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
+            timer t{};
+            auto result = f();
+            double t1   = t.record<milliseconds>();
+            ctx.finish();
+            double t2 = t.record<milliseconds>();
+            std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
+            if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
+               not result.empty())
+            {
+                migraphx::argument buffer;
+                try
+                {
+                    const target& tgt = this->impl->targets.at(ins->get_target_id());
+                    buffer            = tgt.copy_from(result);
+                }
+                catch(const migraphx::exception&)
+                {
+                    // The instruction ran on the host, so there is no need to copy
+                    // the buffer from the target.
+                    buffer = result;
+                }
+                catch(...)
+                {
+                    MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
+                }
+                if(trace_level == 2)
+                {
+                    std::cout << "Output has " << to_string_range(classify_argument(buffer))
+                              << std::endl;
+                    std::cout << "Output: ";
+                    preview_argument(std::cout, buffer);
+                    std::cout << std::endl;
+                    print_statistics(std::cout, buffer);
+                }
+                else
+                {
+                    std::cout << "Output: " << buffer << std::endl;
+                }
+            }
+            return result;
+        });
     }
     else
     {
-        ret = generic_eval(*this,
-                           ctx,
-                           std::move(params),
-                           with_check_context([&](auto&, auto f, auto&& check_context) {
-                               return check_context(f);
-                           }));
+        ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
     }

     if(exec_env.async)
     {
-        ctx.finish_on(exec_env.queue);
+        assert(contexts.size() == 1);
+        contexts.front().finish_on(exec_env.queue);
     }

     return ret;
 }
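A hedged sketch of driving the tracing path above: with MIGRAPHX_TRACE_EVAL=2 set in the environment, eval() synchronizes each instruction's own target context around the computation and, with this change, appends print_statistics output for every traced result. generate_argument here simply fabricates inputs for illustration:

    #include <migraphx/program.hpp>
    #include <migraphx/generate.hpp>

    std::vector<migraphx::argument> run_traced(migraphx::program& p)
    {
        migraphx::parameter_map params;
        for(auto&& [name, s] : p.get_parameter_shapes())
            params[name] = migraphx::generate_argument(s);
        // Per instruction: "Run instruction: ...", two timings, an output
        // preview, and min/max/mean/stddev statistics.
        return p.eval(params);
    }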
-const int program_file_version = 5;
+void program::finish() const
+{
+    for(const auto& ctx : this->impl->contexts)
+        ctx.finish();
+}
+
+const int program_file_version = 6;

 value program::to_value() const
 {
     value result;
     result["version"] = program_file_version;
-    result["target"]  = this->impl->target_name;
-    if(not this->impl->target_name.empty())
-        result["context"] = this->impl->ctx.to_value();
+    result["targets"]  = migraphx::to_value(this->impl->targets);
+    result["contexts"] = migraphx::to_value(this->impl->contexts);

     value module_vals = value::object{};
     std::unordered_map<instruction_ref, std::string> names;
@@ -742,12 +746,12 @@ void program::from_value(const value& v)
         MIGRAPHX_THROW("Warning: Program version mismatch");
     }

-    this->impl->target_name = v.at("target").to<std::string>();
-    if(not this->impl->target_name.empty())
+    migraphx::from_value(v.at("targets"), this->impl->targets);
+
+    for(auto i : range(this->impl->targets.size()))
     {
-        target t        = make_target(this->impl->target_name);
-        this->impl->ctx = t.get_context();
-        this->impl->ctx.from_value(v.at("context"));
+        this->impl->contexts.push_back(this->impl->targets[i].get_context());
+        this->impl->contexts.back().from_value(v.at("contexts")[i]);
     }

     auto module_vals = v.at("modules");
@@ -768,7 +772,9 @@ void program::from_value(const value& v)
     auto* mm = get_main_module();
     mod_from_val(mm, module_vals, map_insts, map_mods);

-    this->finalize();
+    // Finalize a compiled model
+    if(not this->impl->contexts.empty())
+        this->finalize();
 }

 double common_average(const std::vector<double>& v)
@@ -788,19 +794,19 @@ std::string perf_group(const operation& op)

 void program::mark(const parameter_map& params, marker&& m)
 {
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
     // Run once by itself
     eval(params);
-    ctx.finish();
+    this->finish();
     // Start marking
     m.mark_start(*this);
-    generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
-        argument result;
-        m.mark_start(ins);
-        result = f();
-        m.mark_stop(ins);
-        return result;
-    }));
+    generic_eval(*this, ctx, params, [&](auto ins, auto f) {
+        argument result;
+        m.mark_start(ins);
+        result = f();
+        m.mark_stop(ins);
+        return result;
+    });
     m.mark_stop(*this);
 }

@@ -809,10 +815,10 @@ void program::perf_report(std::ostream& os,
                           parameter_map params,
                           std::size_t batch) const
 {
-    auto& ctx = this->impl->ctx;
+    auto& ctx = this->impl->contexts;
     // Run once by itself
     eval(params);
-    ctx.finish();
+    this->finish();
     // Run and time entire program
     std::vector<double> total_vec;
     total_vec.reserve(n);
@@ -820,28 +826,28 @@ void program::perf_report(std::ostream& os,
     {
         total_vec.push_back(time<milliseconds>([&] {
             eval(params);
-            ctx.finish();
+            this->finish();
         }));
     }
     std::sort(total_vec.begin(), total_vec.end());
     std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
     // Fill the map
-    generic_eval(*this, ctx, params, always([&](auto ins, auto) {
-        ins_vec[ins].reserve(n);
-        return argument{ins->get_shape(), nullptr};
-    }));
+    generic_eval(*this, ctx, params, [&](auto ins, auto) {
+        ins_vec[ins].reserve(n);
+        return argument{ins->get_shape(), nullptr};
+    });
     // Run and time each instruction
     for(std::size_t i = 0; i < n; i++)
    {
-        generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
-            argument result;
-            ins_vec[ins].push_back(time<milliseconds>([&] {
-                result = f();
-                ctx.finish();
-            }));
-            return result;
-        }));
+        generic_eval(*this, ctx, params, [&](auto ins, auto f) {
+            argument result;
+            ins_vec[ins].push_back(time<milliseconds>([&] {
+                result = f();
+                this->impl->contexts[ins->get_target_id()].finish();
+            }));
+            return result;
+        });
     }
     for(auto&& p : ins_vec)
         std::sort(p.second.begin(), p.second.end());
@@ -1009,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const

 void program::dry_run(std::unordered_map<std::string, argument> params) const
 {
-    auto& ctx = this->impl->ctx;
-    generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
-        return argument{ins->get_shape(), nullptr};
-    }));
+    auto& ctx = this->impl->contexts;
+    generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
+        return argument{ins->get_shape(), nullptr};
+    });
 }

 void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
...
@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
 {
     module& m              = mpm.get_module();
     module_ref root_module = mpm.get_root_module();
-    if(m.name() == "main")
+    if(m == *root_module)
         return;
     for(auto ins : iterator_for(m))
...
@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names)
         auto mod_inputs = ins->module_inputs();
         auto s          = ins->get_shape();

-        // Convert back to original type before quantizing the inputs
-        if(mod_inputs.empty())
-        {
-            auto r = m.insert_instruction(
-                std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
-            m.replace_instruction(ins, r);
-        }
-
         // Convert each of the inputs that are floating point to fp16
         auto inputs = ins->inputs();
         std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names)
                 ins, make_op("convert", {{"target_type", shape::half_type}}), input);
         });

-        // Replace inputs
-        m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Insert the quantized instruction
+        auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
+        // Convert back to the original type after quantizing
+        if(mod_inputs.empty())
+        {
+            converted_ins = m.insert_instruction(
+                ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
+        }
+        // Replace the original instruction
+        m.replace_instruction(ins, converted_ins);
     }
 }
...
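The reordered rewrite above (insert the half-precision copy of the op, convert back to the original type, then replace) keeps the module's observable output types unchanged. It is driven by the public fp16 quantization entry point; a hedged usage sketch:

    #include <migraphx/quantization.hpp>
    #include <migraphx/program.hpp>

    void to_fp16(migraphx::program& p)
    {
        // Rewrites every eligible instruction to compute in half precision,
        // with a trailing convert restoring the original type.
        migraphx::quantize_fp16(p, {"all"});
    }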
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/pass_manager.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/program.hpp>
@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model)
     mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
 }

-void replace_allocate::apply(module& m) const
+void replace_allocate::apply(module_pass_manager& mpm) const
 {
+    module& m             = mpm.get_module();
     auto mod_output_names = create_output_names(m);
-    bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
+    bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
     for(auto ins : iterator_for(m))
     {
         auto op = ins->get_operator();
@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
             continue;

         auto s = ins->get_shape();
-        if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
+        if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
         {
             auto out_param = m.add_parameter(mod_output_names[ins], s);
             m.replace_instruction(ins, out_param);
...
@@ -39,8 +39,6 @@
 #include <migraphx/algorithm.hpp>
 #include <unordered_set>

-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES)
-
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -1487,13 +1485,10 @@ struct find_split_transpose

 void simplify_algebra::apply(module& m) const
 {
-    size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{});
     // Run simplifications multiple times
     for(int i = 0; i < 8; i++)
     {
-        match::find_matches(trace,
-                            m,
+        match::find_matches(m,
                             find_inner_broadcast{},
                             find_dot_broadcast{},
                             find_double_add_lit_broadcast{},
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }
void migraphx_from_value(const value& v, target& t)
{
t = make_target(v.at("name").to<std::string>());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
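Since a target serializes as its name only, deserialization round-trips through make_target and therefore only works for targets that can be found in the registry. A hedged round-trip sketch using the reference target:

    #include <migraphx/target.hpp>
    #include <migraphx/register_target.hpp>
    #include <migraphx/serialize.hpp>
    #include <cassert>

    int main()
    {
        migraphx::target t = migraphx::make_target("ref"); // assumes "ref" is registered
        migraphx::value v  = migraphx::to_value(t);        // {"name": "ref"}
        migraphx::target u;
        migraphx::from_value(v, u);                        // calls make_target("ref")
        assert(u.name() == "ref");
    }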
@@ -88,8 +88,6 @@ foreach(LIBRARY ${OpenMP_CXX_LIBRARIES})
   endif()
 endforeach()

-target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu)
-
 rocm_install_targets(
   TARGETS migraphx_cpu
   INCLUDE
...
@@ -170,7 +170,11 @@ struct compile_plan
         if(results.empty())
             MIGRAPHX_THROW("No configs to tune");
         if(results.size() == 1)
+        {
+            if(not results.front().has_value())
+                MIGRAPHX_THROW("No configs to tune");
             return *results.front();
+        }
         if(not config)
             MIGRAPHX_THROW("Multiple kernels without config");
         std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
@@ -185,6 +189,7 @@ struct compile_plan
                 .first;
         });
         auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
+        std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
         pc.insert(preop.name(), config->problem, config->solutions.at(i));
         if(not results[i].has_value())
             MIGRAPHX_THROW("No valid tuned compilation.");
...
@@ -83,10 +83,23 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
+    auto m = a.lens()[a.lens().size() - 2];
+    auto n = b.lens().back();
+    auto k = a.lens().back();
+    // Integer gemms must have dimensions divisible by 4 in CK
+    if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
+    {
+        if(m % 4 != 0)
+            return false;
+        if(n % 4 != 0)
+            return false;
+        if(k % 4 != 0)
+            return false;
+    }
     // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
     // to avoid poor-performing GEMM kernels from CK
     // To-do: Investigate a more precise strategy
-    return a.lens().back() <= 2048;
+    return k <= 2048;
 }

 struct find_ck_gemm_pointwise
...
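The predicate above adds two gates: integer GEMMs must have M, N, and K all divisible by 4, and any GEMM with K > 2048 is skipped regardless of type. The same arithmetic as a standalone sketch (row-major [M,K] x [K,N] lens assumed):

    #include <cstddef>

    // Mirrors the checks above: int8/int32 GEMMs need 4-divisible dimensions,
    // and large-K GEMMs are left to other kernels.
    bool ck_gemm_eligible(std::size_t m, std::size_t n, std::size_t k, bool integer_gemm)
    {
        if(integer_gemm and (m % 4 != 0 or n % 4 != 0 or k % 4 != 0))
            return false;
        return k <= 2048;
    }
    // ck_gemm_eligible(64, 64, 64, true)    -> true
    // ck_gemm_eligible(62, 64, 64, true)    -> false (M not divisible by 4)
    // ck_gemm_eligible(64, 64, 4096, false) -> false (K too large)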
@@ -139,7 +139,8 @@ struct find_mlir_op
     auto matcher() const
     {
         auto dot_or_conv = match::skip(match::name("contiguous"))(
-            match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
+            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
+                .bind("gemm_based_op"));
         return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }
@@ -190,6 +191,68 @@ struct find_mlir_op
         return {new_gemm_based_op, top_inputs};
     }

+    // Whitelist supported fusion options, including imposing type constraints
+    // for cases where MLIR only supports an operation (usually a pointwise function)
+    // on particular types.
+    bool is_pointwise_op_supported_by_mlir(const instruction& i) const
+    {
+        using type_t           = shape::type_t;
+        const auto& name       = i.name();
+        const auto result_type = i.get_shape().type();
+        const std::initializer_list<type_t> allowed_types = {type_t::float_type,
+                                                             type_t::half_type,
+                                                             type_t::int8_type,
+                                                             type_t::int32_type,
+                                                             type_t::bool_type};
+        // Preliminary type check.
+        if(not contains(allowed_types, result_type))
+        {
+            return false;
+        }
+        const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
+        const std::initializer_list<std::string> no_bool_ops  = {"convolution",
+                                                                 "quant_convolution",
+                                                                 "dot",
+                                                                 "quant_dot",
+                                                                 "add",
+                                                                 "clip",
+                                                                 "sub",
+                                                                 "mul",
+                                                                 "div",
+                                                                 "pow",
+                                                                 "where",
+                                                                 "quantizelinear",
+                                                                 "dequantizelinear",
+                                                                 "abs",
+                                                                 "neg"};
+        const std::initializer_list<std::string> fp_only_ops = {"ceil",
+                                                                "erf",
+                                                                "exp",
+                                                                "floor",
+                                                                "log",
+                                                                "recip",
+                                                                "rsqrt",
+                                                                "sigmoid",
+                                                                "softmax",
+                                                                "tanh"};
+        bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
+        if(contains(any_type_ops, name))
+            return true;
+        if(result_type != type_t::bool_type && contains(no_bool_ops, name))
+            return true;
+        if(is_float && contains(fp_only_ops, name))
+            return true;
+        // Only conversions between floating-point types are known to be unambiguously
+        // supported.
+        if(is_float && name == "convert")
+        {
+            return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
+                return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
+            });
+        }
+        return false;
+    }

     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;
@@ -197,31 +260,12 @@ struct find_mlir_op
         auto x_ins = r.instructions["x"]; // input after contiguous
         auto* pm   = ins->module_inputs().front();
         auto names = pm->get_parameter_names();
-        // Whitelist pointwise operators
-        if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
-               return not contains({"@literal",
-                                    "@param",
-                                    "@return",
-                                    "convolution",
-                                    "quant_convolution",
-                                    "dot",
-                                    "add",
-                                    "relu",
-                                    "dequantizelinear",
-                                    "quantizelinear",
-                                    "mul"},
-                                   i.name());
-           }))
-            return;
-        // Only fuse with fp32/fp16/int8/int32
-        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
-               return not contains({shape::type_t::float_type,
-                                    shape::type_t::half_type,
-                                    shape::type_t::int8_type,
-                                    shape::type_t::int32_type},
-                                   i->get_shape().type());
+        // Whitelist pointwise operators.
+        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+               return not is_pointwise_op_supported_by_mlir(i);
           }))
            return;
         std::sort(names.begin(), names.end());
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
...
@@ -30,7 +30,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct module;
+struct module_pass_manager;

 namespace gpu {
@@ -45,7 +45,7 @@ struct lowering
     context* ctx;
     bool offload_copy;
     std::string name() const { return "gpu::lowering"; }
-    void apply(module& m) const;
+    void apply(module_pass_manager& mpm) const;
 };

 } // namespace gpu
...
@@ -66,7 +66,7 @@ ${preamble}

 extern "C" {

-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
         ck_gemm<${solution}, ${blocks_per_batch}>(xs...);
@@ -266,7 +266,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         s = shape{s.type(), {m1, m2}};
     }

-    std::vector<std::string> names() const { return {"gpu::ck_gemm"}; }
+    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

     static bool standard_batch(const shape& s)
     {
@@ -419,9 +419,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
     {
         auto shapes = to_shapes(ins->inputs());
         auto v      = create_settings(ins, op);
-        if(solution.is_null())
-            v["tuning_value"] = 4;
-        else
+        if(not solution.is_null())
             v["tuning_value"] = solution;
         return {compile_op(ctx, shapes, v),
                 [=](module& m, instruction_ref ins2, const operation& code_object) {
...
@@ -47,7 +47,7 @@ ${preamble}

 extern "C" {

-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
         concat<${axis}>(${concat_args})(${post}, y, xs...);
...
@@ -44,7 +44,7 @@ namespace migraphx {

 extern "C" {

-__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
+MIGRAPHX_GLOBAL void gather_kernel(void* in_data, void* in_indices, void* output)
 {
     make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
         gather<${axis}>(xs...);
...
@@ -44,7 +44,7 @@ namespace migraphx {

 extern "C" {

-__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
+MIGRAPHX_GLOBAL void gathernd_kernel(void* in_data, void* in_indices, void* output)
 {
     make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
         auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
...
@@ -48,7 +48,7 @@ namespace migraphx {

 ${preamble}

 extern "C" {

-__global__ void ${kernel}(${params})
+MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
     transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
         ${layernorm}<${axis}>(${post}, ${eps}, xs...);
...