Commit 712f6134 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch and resolve merge conflicts

parents 4a39a0f7 b20e3d4d
...@@ -26,6 +26,7 @@ add_library(migraphx ...@@ -26,6 +26,7 @@ add_library(migraphx
eliminate_pad.cpp eliminate_pad.cpp
env.cpp env.cpp
file_buffer.cpp file_buffer.cpp
fuse_pointwise.cpp
generate.cpp generate.cpp
inline_module.cpp inline_module.cpp
insert_pad.cpp insert_pad.cpp
...@@ -130,9 +131,11 @@ register_migraphx_ops( ...@@ -130,9 +131,11 @@ register_migraphx_ops(
multibroadcast multibroadcast
multinomial multinomial
neg neg
nonmaxsuppression
nonzero nonzero
outline outline
pad pad
pointwise
pooling pooling
pow pow
prefix_scan_sum prefix_scan_sum
...@@ -153,6 +156,7 @@ register_migraphx_ops( ...@@ -153,6 +156,7 @@ register_migraphx_ops(
rnn_last_cell_output rnn_last_cell_output
rnn_last_hs_output rnn_last_hs_output
rnn_var_sl_last_output rnn_var_sl_last_output
roialign
round round
rsqrt rsqrt
scalar scalar
...@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl) ...@@ -198,6 +202,9 @@ target_link_libraries(migraphx PRIVATE -ldl)
target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>) target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
find_package(Threads)
target_link_libraries(migraphx PUBLIC Threads::Threads)
find_package(msgpack REQUIRED) find_package(msgpack REQUIRED)
target_link_libraries(migraphx PRIVATE msgpackc-cxx) target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests # Make this available to the tests
...@@ -235,6 +242,7 @@ rocm_export_targets( ...@@ -235,6 +242,7 @@ rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets TARGETS migraphx::migraphx migraphx_all_targets
NAMESPACE migraphx:: NAMESPACE migraphx::
DEPENDS DEPENDS
Threads
${PACKAGE_DEPENDS} ${PACKAGE_DEPENDS}
) )
......
...@@ -3,7 +3,7 @@ add_library(migraphx_c ...@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 2.0) rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c) rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
......
This diff is collapsed.
...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); ...@@ -209,7 +209,8 @@ migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation, migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name, const char* name,
const char* attributes); const char* attributes,
...);
migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation); migraphx_status migraphx_operation_name(char* out, size_t out_size, migraphx_operation_t operation);
......
...@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -252,7 +252,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout; const size_t* pout;
size_t pout_size; size_t pout_size;
call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr()); call(&migraphx_shape_lengths, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size); return {pout, pout + pout_size};
} }
std::vector<size_t> strides() const std::vector<size_t> strides() const
...@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -260,7 +260,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
const size_t* pout; const size_t* pout;
size_t pout_size; size_t pout_size;
call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr()); call(&migraphx_shape_strides, &pout, &pout_size, this->get_handle_ptr());
return std::vector<size_t>(pout, pout + pout_size); return {pout, pout + pout_size};
} }
migraphx_shape_datatype_t type() const migraphx_shape_datatype_t type() const
...@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -312,7 +312,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return shape(pout); return {pout};
} }
char* data() const char* data() const
...@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -325,9 +325,8 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
/// Generate an argument using random data /// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0) static argument generate(shape ps, size_t pseed = 0)
{ {
return argument( return {make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed),
make<migraphx_argument>(&migraphx_argument_generate, ps.get_handle_ptr(), pseed), own{}};
own{});
} }
friend bool operator==(const argument& px, const argument& py) friend bool operator==(const argument& px, const argument& py)
...@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -378,7 +377,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return shape(pout); return {pout};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -438,7 +437,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return argument(pout); return {pout};
} }
struct iterator_read struct iterator_read
...@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -449,7 +448,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx); call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout); return {pout};
} }
}; };
}; };
...@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -471,7 +470,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return shape(pout); return {pout};
} }
struct iterator_read struct iterator_read
...@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -481,7 +480,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx); call(&migraphx_shapes_get, &pout, self, pidx);
return shape(pout); return {pout};
} }
}; };
}; };
...@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -599,16 +598,17 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); } operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
operation(const char* name, const char* attributes = nullptr) template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{ {
this->make_handle(&migraphx_operation_create, name, attributes); this->make_handle(&migraphx_operation_create, name, attributes, xs...);
} }
std::string name() std::string name()
{ {
std::array<char, 1024> out_name; std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr()); call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return std::string(out_name.data()); return {out_name.data()};
} }
}; };
......
...@@ -212,7 +212,9 @@ def program(h): ...@@ -212,7 +212,9 @@ def program(h):
@auto_handle() @auto_handle()
def operation(h): def operation(h):
h.constructor('create', h.constructor('create',
api.params(name='const char*', attributes='const char*'), api.params(name='const char*',
attributes='const char*',
vlist='...'),
fname='migraphx::create_op') fname='migraphx::create_op')
h.method('name', returns='std::string') h.method('name', returns='std::string')
......
...@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const ...@@ -155,5 +155,13 @@ std::vector<argument> argument::get_sub_objects() const
return result; return result;
} }
argument argument::element(std::size_t i) const
{
assert(this->get_shape().sub_shapes().empty());
auto idx = this->get_shape().index(i);
auto offset = this->get_shape().type_size() * idx;
return argument{shape{this->get_shape().type()}, this->data() + offset};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -26,17 +27,19 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate ...@@ -26,17 +27,19 @@ cpp_generator::function::set_body(const module& m, const cpp_generator::generate
{ {
names[ins] = names[ins] =
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter; migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter;
continue;
} }
if(ins->name() == "@return") else if(ins->name() == "@return")
{ {
assert(ins->inputs().size() == 1); assert(ins->inputs().size() == 1);
return_ins = ins->inputs().front(); return_ins = ins->inputs().front();
} }
else
{
std::string n = "z" + std::to_string(names.size()); std::string n = "z" + std::to_string(names.size());
names[ins] = n; names[ins] = n;
ss << "auto " << n << " = " << g(ins, names) << ";\n"; ss << "auto " << n << " = " << g(ins, names) << ";\n";
} }
}
ss << "return " << names.at(return_ins) << ";\n"; ss << "return " << names.at(return_ins) << ";\n";
body = ss.str(); body = ss.str();
return *this; return *this;
...@@ -49,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m) ...@@ -49,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m)
cpp_generator::function& cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse) cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{ {
this->params.clear();
auto pmap = m.get_parameter_shapes(); auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end()); std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform( std::transform(
...@@ -61,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str ...@@ -61,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str
return *this; return *this;
} }
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
struct cpp_generator_impl struct cpp_generator_impl
{ {
std::stringstream fs{}; std::stringstream fs{};
std::size_t function_count = 0; std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr; std::function<std::string(std::string)> fmap = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
}; };
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {} cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
...@@ -81,12 +104,28 @@ cpp_generator::~cpp_generator() noexcept = default; ...@@ -81,12 +104,28 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; } void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op, std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args) const std::vector<std::string>& args)
{ {
auto v = op.to_value(); auto v = op.to_value();
return interpolate_string(op.attributes()["point_op"].to<std::string>(), std::string code;
[&](auto start, auto last) -> std::string { if(contains(impl->point_op_map, op.name()))
{
code = impl->point_op_map.at(op.name());
}
else
{
auto attributes = op.attributes();
if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
code = attributes["point_op"].to<std::string>();
}
return interpolate_string(code, [&](auto start, auto last) -> std::string {
auto key = trim({start, last}); auto key = trim({start, last});
if(key.empty()) if(key.empty())
MIGRAPHX_THROW("Empty parameter"); MIGRAPHX_THROW("Empty parameter");
...@@ -120,7 +159,12 @@ std::string cpp_generator::str() const { return impl->fs.str(); } ...@@ -120,7 +159,12 @@ std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m) cpp_generator::function cpp_generator::generate_module(const module& m)
{ {
function f; function f;
f.set_name(m.name()).set_types(m).set_body( auto name = transform_string(m.name(), [](char c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return '_';
});
f.set_name(name).set_types(m).set_body(
m, [&](instruction_ref ins, const auto& names) -> std::string { m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal") if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" + return shape::cpp_type(ins->get_shape().type()) + "(" +
...@@ -130,7 +174,6 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -130,7 +174,6 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(args), std::back_inserter(args),
[&](auto i) { return names.at(i); }); [&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
return this->generate_point_op(ins->get_operator(), args); return this->generate_point_op(ins->get_operator(), args);
}); });
return f; return f;
...@@ -139,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -139,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
std::string cpp_generator::create_function(const cpp_generator::function& f) std::string cpp_generator::create_function(const cpp_generator::function& f)
{ {
impl->function_count++; impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name; std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name; impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '('; char delim = '(';
......
...@@ -6,6 +6,7 @@ add_executable(driver ...@@ -6,6 +6,7 @@ add_executable(driver
resnet50.cpp resnet50.cpp
inceptionv3.cpp inceptionv3.cpp
alexnet.cpp alexnet.cpp
marker_roctx.cpp
) )
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver) set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility # Copy driver for backwards compatibility
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/rank.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -106,10 +107,22 @@ struct argument_parser ...@@ -106,10 +107,22 @@ struct argument_parser
return to_string_range(x); return to_string_range(x);
} }
template <class T>
auto as_string_value(rank<1>, const T& x) -> decltype(to_string(x))
{
return to_string(x);
}
template <class T>
std::string as_string_value(rank<0>, const T&)
{
throw std::runtime_error("Can't convert to string");
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})> template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x) std::string as_string_value(const T& x)
{ {
return to_string(x); return as_string_value(rank<1>{}, x);
} }
template <class T, class... Fs> template <class T, class... Fs>
...@@ -124,8 +137,9 @@ struct argument_parser ...@@ -124,8 +137,9 @@ struct argument_parser
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = migraphx::get_type_name<T>();
arg.default_value = as_string_value(x);
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x);
} }
template <class... Fs> template <class... Fs>
......
#include "verify.hpp"
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
#include "verify.hpp" #include "precision.hpp"
#include "perf.hpp" #include "perf.hpp"
#include "models.hpp" #include "models.hpp"
#include "marker_roctx.hpp"
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
...@@ -287,14 +289,12 @@ struct compiler_target ...@@ -287,14 +289,12 @@ struct compiler_target
struct compiler struct compiler
{ {
static const int q_fp16 = 1;
static const int q_int8 = 2;
loader l; loader l;
program_params parameters; program_params parameters;
compiler_target ct; compiler_target ct;
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true; bool fast_math = true;
int quantize = 0; precision quantize = precision::fp32;
std::vector<std::string> fill0; std::vector<std::string> fill0;
std::vector<std::string> fill1; std::vector<std::string> fill1;
...@@ -311,8 +311,8 @@ struct compiler ...@@ -311,8 +311,8 @@ struct compiler
{"--disable-fast-math"}, {"--disable-fast-math"},
ap.help("Disable fast math optimization"), ap.help("Disable fast math optimization"),
ap.set_value(false)); ap.set_value(false));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16)); ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8)); ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
} }
auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); } auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); }
...@@ -324,11 +324,11 @@ struct compiler ...@@ -324,11 +324,11 @@ struct compiler
if(p.is_compiled()) if(p.is_compiled())
return p; return p;
auto t = ct.get_target(); auto t = ct.get_target();
if(quantize == q_fp16) if(quantize == precision::fp16)
{ {
quantize_fp16(p); quantize_fp16(p);
} }
else if(quantize == q_int8) else if(quantize == precision::int8)
{ {
quantize_int8(p, t, {params(p)}); quantize_int8(p, t, {params(p)});
} }
...@@ -376,6 +376,7 @@ struct verify : command<verify> ...@@ -376,6 +376,7 @@ struct verify : command<verify>
bool reduce = false; bool reduce = false;
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true; bool fast_math = true;
precision quantize = precision::fp32;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); l.parse(ap);
...@@ -395,6 +396,7 @@ struct verify : command<verify> ...@@ -395,6 +396,7 @@ struct verify : command<verify>
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
ap.set_value(true)); ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true)); ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
} }
void run() void run()
...@@ -411,15 +413,15 @@ struct verify : command<verify> ...@@ -411,15 +413,15 @@ struct verify : command<verify>
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, options, tolerance); verify_instructions(p, t, options, quantize, tolerance);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, options, m, tolerance); verify_reduced_program(p, t, options, quantize, m, tolerance);
} }
else else
{ {
verify_program(l.file, p, t, options, m, tolerance); verify_program(l.file, p, t, options, quantize, m, tolerance);
} }
} }
}; };
...@@ -479,7 +481,24 @@ struct perf : command<perf> ...@@ -479,7 +481,24 @@ struct perf : command<perf>
std::cout << "Allocating params ... " << std::endl; std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p); auto m = c.params(p);
std::cout << "Running performance report ... " << std::endl; std::cout << "Running performance report ... " << std::endl;
p.perf_report(std::cout, n, m); p.perf_report(std::cout, n, m, c.l.batch);
}
};
struct roctx : command<roctx>
{
compiler c;
void parse(argument_parser& ap) { c.parse(ap); }
void run()
{
std::cout << "Compiling ... " << std::endl;
auto p = c.compile();
std::cout << "Allocating params ... " << std::endl;
auto m = c.params(p);
std::cout << "rocTX:\tLoading rocTX library..." << std::endl;
auto rtx = create_marker_roctx();
p.mark(m, std::move(rtx));
} }
}; };
......
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
class marker_roctx
{
std::function<void(const char*)> sym_roctx_mark;
std::function<uint64_t(const char*)> sym_roctx_range_start;
std::function<void(uint64_t)> sym_roctx_range_stop;
std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop;
uint64_t range_id;
public:
marker_roctx()
{
dynamic_loader lib = migraphx::dynamic_loader{"libroctx64.so"};
sym_roctx_mark = lib.get_function<void(const char*)>("roctxMarkA");
sym_roctx_range_start = lib.get_function<uint64_t(const char*)>("roctxRangeStartA");
sym_roctx_range_stop = lib.get_function<void(uint64_t)>("roctxRangeStop");
sym_roctx_range_push = lib.get_function<int(const char*)>("roctxRangePushA");
sym_roctx_range_pop = lib.get_function<int()>("roctxRangePop");
sym_roctx_mark("rocTX marker created.");
}
void mark_start(instruction_ref ins_ref)
{
std::string text = "Marker start: " + ins_ref->name();
sym_roctx_range_push(text.c_str());
}
void mark_stop(instruction_ref) { sym_roctx_range_pop(); }
void mark_start(const program&) { range_id = sym_roctx_range_start("0"); }
void mark_stop(const program&) { sym_roctx_range_stop(range_id); }
};
marker create_marker_roctx() { return marker_roctx(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#define MIGRAPHX_GUARD_RTGLIB_MARKER_ROCTX_HPP
#include <migraphx/marker.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
marker create_marker_roctx();
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP
#define MIGRAPHX_GUARD_RTGLIB_PRECISION_HPP
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
enum class precision
{
fp32,
fp16,
int8
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
#endif
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -19,9 +20,16 @@ std::vector<argument> run_ref(program p, const parameter_map& inputs) ...@@ -19,9 +20,16 @@ std::vector<argument> run_ref(program p, const parameter_map& inputs)
return out; return out;
} }
std::vector<argument> std::vector<argument> run_target(program p,
run_target(program p, const target& t, const compile_options& options, const parameter_map& inputs) const target& t,
const compile_options& options,
precision quantize,
const parameter_map& inputs)
{ {
if(quantize == precision::fp16)
{
quantize_fp16(p);
}
p.compile(t, options); p.compile(t, options);
parameter_map m; parameter_map m;
...@@ -43,24 +51,24 @@ void verify_program(const std::string& name, ...@@ -43,24 +51,24 @@ void verify_program(const std::string& name,
const program& p, const program& p,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
auto x = run_ref(p, inputs); auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, inputs); auto y = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size(); std::size_t output_num = x.size();
for(std::size_t i = 0; i < output_num; ++i) for(std::size_t i = 0; i < output_num; ++i)
{ {
verify_args(name, x[i], y[i], tolerance); verify_args(name, x[i], y[i], tolerance);
} }
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
} }
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize,
double tolerance) double tolerance)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
...@@ -92,7 +100,8 @@ void verify_instructions(const program& prog, ...@@ -92,7 +100,8 @@ void verify_instructions(const program& prog,
{ {
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(ins.name(), p, t, options, create_param_map(p, false), tolerance); verify_program(
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
} }
catch(...) catch(...)
{ {
...@@ -106,6 +115,7 @@ void verify_reduced(program p, ...@@ -106,6 +115,7 @@ void verify_reduced(program p,
int n, int n,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
...@@ -114,12 +124,13 @@ void verify_reduced(program p, ...@@ -114,12 +124,13 @@ void verify_reduced(program p,
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl; std::cout << "Verify: " << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, inputs, tolerance); verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
} }
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
...@@ -127,7 +138,7 @@ void verify_reduced_program(const program& p, ...@@ -127,7 +138,7 @@ void verify_reduced_program(const program& p,
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(p, i, t, options, inputs, tolerance); verify_reduced(p, i, t, options, quantize, inputs, tolerance);
} }
} }
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP #define MIGRAPHX_GUARD_RTGLIB_DRIVER_VERIFY_HPP
#include "precision.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
namespace migraphx { namespace migraphx {
...@@ -11,15 +12,18 @@ void verify_program(const std::string& name, ...@@ -11,15 +12,18 @@ void verify_program(const std::string& name,
const program& p, const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); double tolerance = 100);
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32,
double tolerance = 80); double tolerance = 80);
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); double tolerance = 80);
......
...@@ -45,6 +45,7 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer) ...@@ -45,6 +45,7 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{ {
dlerror();
void* symbol = dlsym(impl->handle.get(), name.c_str()); void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr) if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name); MIGRAPHX_THROW("Symbol not found: " + name);
......
...@@ -11,11 +11,13 @@ ...@@ -11,11 +11,13 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inputs) static bool try_compute_shape(instruction_ref ins,
const std::vector<shape>& inputs,
const std::vector<module_ref>& mods)
{ {
try try
{ {
shape new_shape = ins->get_operator().compute_shape(inputs); shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// If the output shape is a standard shape, no need to try its output // If the output shape is a standard shape, no need to try its output
if(new_shape.standard()) if(new_shape.standard())
{ {
...@@ -45,7 +47,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp ...@@ -45,7 +47,7 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp
return (arg == ins) ? new_shape : arg->get_shape(); return (arg == ins) ? new_shape : arg->get_shape();
}); });
if(!try_compute_shape(output, input_shapes)) if(!try_compute_shape(output, input_shapes, mods))
{ {
return false; return false;
} }
...@@ -59,10 +61,12 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp ...@@ -59,10 +61,12 @@ static bool try_compute_shape(instruction_ref ins, const std::vector<shape>& inp
return true; return true;
} }
static bool try_compute_shape(instruction_ref ins, const std::vector<instruction_ref>& args) static bool try_compute_shape(instruction_ref ins,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods)
{ {
auto inputs = to_shapes(args); auto inputs = to_shapes(args);
return try_compute_shape(ins, inputs); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& p) const void eliminate_contiguous::apply(module& p) const
...@@ -82,7 +86,7 @@ void eliminate_contiguous::apply(module& p) const ...@@ -82,7 +86,7 @@ void eliminate_contiguous::apply(module& p) const
auto new_args = args; auto new_args = args;
auto prev = arg->inputs().front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args)) if(try_compute_shape(ins, new_args, ins->module_inputs()))
{ {
instruction::replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
......
...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const void eliminate_data_type::apply(module& m) const
{ {
static const std::vector<std::string> skip_op_names = { static const std::vector<std::string> skip_op_names = {
"convert", "get_tuple_elem", "if", "loop"}; "convert", "get_tuple_elem", "if", "loop", "roialign"};
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name()[0] == '@') if(ins->name()[0] == '@')
......
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static literal get_scalar(instruction_ref ins)
{
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
return r;
}
static void create_pointwise_modules(module_pass_manager& mpm)
{
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
pointwise_inputs.push_back(input);
param_map[input] =
pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()});
i++;
}
else
{
param_map[input] = pm->add_literal(scalar);
}
}
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return param_map[input]; });
auto r = pm->add_instruction(ins->get_operator(), inputs);
pm->add_return({r});
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
}
}
static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
instruction_ref output)
{
assert(contains(output->inputs(), ins));
module_ref pm = ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);
auto last = std::prev(pm->end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);
assert(pm->get_parameter_names().size() == ins->inputs().size());
assert(xm->get_parameter_names().size() == output->inputs().size());
std::vector<instruction_ref> inputs = ins->inputs();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm->get_parameter("x" + std::to_string(i));
assert(param != pm->end());
input_map[input] = param;
}
// Add the new parameter and additional inputs
for(auto i : range(output->inputs().size()))
{
auto input = output->inputs()[i];
auto param = xm->get_parameter("x" + std::to_string(i));
assert(param != xm->end());
if(input == ins)
{
map_ins[param] = last->inputs().front();
input_map[input] = map_ins[param];
}
// Avoid duplicate paramter inputs
else if(contains(input_map, input))
{
map_ins[param] = input_map[input];
}
else
{
map_ins[param] =
pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
return inputs;
}
static bool find_pointwise_modules(module& m)
{
bool changed = false;
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty() and ins != last)
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto input = *it;
auto new_inputs = append_pointwise_module(input, ins);
m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs());
m.replace_instruction(ins, input);
m.move_instruction(input, ins);
changed = true;
}
return changed;
}
void fuse_pointwise::apply(module_pass_manager& mpm) const
{
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 8; i++)
{
if(not find_pointwise_modules(mpm.get_module()))
break;
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment