Unverified Commit 476ed17c authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into rand_uniform

parents f4f9d711 6f1c947f
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp) add_library(migraphx_py py_loader.cpp)
migraphx_generate_export_header(migraphx_py)
target_include_directories(migraphx_py PRIVATE include) target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx) target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include) rocm_install_targets(TARGETS migraphx_py INCLUDE include)
......
...@@ -26,11 +26,12 @@ ...@@ -26,11 +26,12 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program load_py(const std::string& filename); MIGRAPHX_PY_EXPORT program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -64,7 +64,7 @@ static dynamic_loader py_lib() ...@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
return lib; return lib;
} }
program load_py(const std::string& filename) MIGRAPHX_PY_EXPORT program load_py(const std::string& filename)
{ {
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py"); static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename); return f(filename);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -62,12 +63,9 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -62,12 +63,9 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant); auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
std::vector<int> max_data(s.elements(), max_quant); auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto min_arg = m.add_literal(literal(s, min_data)); auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
} }
......
...@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) ...@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
}; };
}; };
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2); return (dots >= 2 or convs >= 2 or qdots >= 2);
} }
struct find_conv_dot_horiz_fusion struct find_conv_dot_horiz_fusion
...@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) { auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator()) if(i->get_operator() != j->get_operator())
return false; return false;
if(not contains({"dot", "convolution"}, i->name())) if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true; return true;
auto x = i->inputs()[1]->get_shape().lens(); auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens(); auto y = j->inputs()[1]->get_shape().lens();
...@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return false; return false;
// Check that non-axes match // Check that non-axes match
int axis = 1; int axis = 1;
if(i->name() == "dot") if(i->name() == "dot" or i->name() == "quant_dot")
{ {
axis = x.size() - 1; axis = x.size() - 1;
} }
...@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2) if(std::distance(start, last) < 2)
return; return;
auto&& name = (*start)->name(); auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name)) if(not contains({"quant_dot", "dot", "convolution"}, name))
return; return;
auto op = (*start)->get_operator(); auto op = (*start)->get_operator();
int group = 1; int group = 1;
...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); }); start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1; int axis = 1;
int concat_axis = 0; int concat_axis = 0;
if(name == "dot") if(name == "dot" or name == "quant_dot")
{ {
axis = int(args.front()->get_shape().lens().size() - 1); axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis; concat_axis = axis;
......
...@@ -48,6 +48,7 @@ struct sqlite_impl ...@@ -48,6 +48,7 @@ struct sqlite_impl
template <class F> template <class F>
void exec(const char* sql, F f) void exec(const char* sql, F f)
{ {
// cppcheck-suppress constParameterPointer
auto callback = [](void* obj, auto... xs) -> int { auto callback = [](void* obj, auto... xs) -> int {
try try
{ {
......
...@@ -61,7 +61,7 @@ namespace cpu { ...@@ -61,7 +61,7 @@ namespace cpu {
std::string target::name() const { return "cpu"; } std::string target::name() const { return "cpu"; }
// cppcheck-suppress constParameter // cppcheck-suppress constParameterReference
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const
{ {
auto& ctx = any_cast<context>(gctx); auto& ctx = any_cast<context>(gctx);
......
...@@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
dim3 nthreads(local); dim3 nthreads(local);
// cppcheck-suppress UseDeviceLaunch // cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f); hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
hipError_t kernel_launch_status = hipGetLastError();
if(kernel_launch_status != hipSuccess)
{
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " +
std::string(hipGetErrorString(kernel_launch_status)));
}
}; };
} }
......
...@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl( ...@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl(
buffer[i] = binput.data()[i]; buffer[i] = binput.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); const auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
...@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl( ...@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl(
buffer[i + bdim_vec_len] = binput2.data()[i]; buffer[i + bdim_vec_len] = binput2.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); const auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
......
...@@ -72,12 +72,12 @@ struct hip_heap_vector ...@@ -72,12 +72,12 @@ struct hip_heap_vector
index_int l = 2 * index + 1; index_int l = 2 * index + 1;
index_int r = 2 * index + 2; index_int r = 2 * index + 2;
if(l < n && compare(data[data_index(l)], data[data_index(index)])) if(l < n and compare(data[data_index(l)], data[data_index(index)]))
{ {
index = l; index = l;
} }
if(r < n && compare(data[data_index(r)], data[data_index(index)])) if(r < n and compare(data[data_index(r)], data[data_index(index)]))
{ {
index = r; index = r;
if(compare(data[data_index(l)], data[data_index(r)])) if(compare(data[data_index(l)], data[data_index(r)]))
......
...@@ -31,20 +31,6 @@ namespace migraphx { ...@@ -31,20 +31,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
int get_device_id() int get_device_id()
{ {
int device; int device;
...@@ -60,7 +46,7 @@ std::string get_device_name() ...@@ -60,7 +46,7 @@ std::string get_device_name()
auto status = hipGetDeviceProperties(&props, get_device_id()); auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties"); MIGRAPHX_THROW("Failed to get device properties");
return get_arch_name(props); return props.gcnArchName;
} }
} // namespace gpu } // namespace gpu
......
...@@ -86,7 +86,7 @@ struct mlir_op ...@@ -86,7 +86,7 @@ struct mlir_op
size_t param_cnt = 0; size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names(); std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
for(std::string param_name : names) for(const std::string& param_name : names)
{ {
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++]; ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
} }
...@@ -210,7 +210,8 @@ struct find_mlir_op ...@@ -210,7 +210,8 @@ struct find_mlir_op
return false; return false;
} }
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"}; const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {"convolution", const std::initializer_list<std::string> no_bool_ops = {
"convolution",
"quant_convolution", "quant_convolution",
"dot", "dot",
"quant_dot", "quant_dot",
...@@ -225,27 +226,31 @@ struct find_mlir_op ...@@ -225,27 +226,31 @@ struct find_mlir_op
"quantizelinear", "quantizelinear",
"dequantizelinear", "dequantizelinear",
"abs", "abs",
"neg"}; "neg",
const std::initializer_list<std::string> fp_only_ops = {"ceil", };
const std::initializer_list<std::string> fp_only_ops = {
"ceil",
"erf", "erf",
"exp", "exp",
"floor", "floor",
"log", "log",
"recip", "recip",
"rsqrt", "rsqrt",
"sigmoid" // There are bugs in MLIR right now for models using sigmoid so disable it for now
// "sigmoid",
"softmax", "softmax",
"tanh"}; "tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type && contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true; return true;
if(is_float && contains(fp_only_ops, name)) if(is_float and contains(fp_only_ops, name))
return true; return true;
// Only conversions between floating types are known to be unambigiously // Only conversions between floating types are known to be unambigiously
// supported. // supported.
if(is_float && name == "convert") if(is_float and name == "convert")
{ {
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
......
...@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); ...@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct hip_device struct hip_device
{ {
hip_device() hip_device() : device_props{} { add_stream(); }
{
device_props.gcnArchName[0] = '\0';
device_props.gcnArch = 0;
device_props.multiProcessorCount = 0;
add_stream();
}
hip_device(std::size_t id, std::size_t n) : device_id(id) hip_device(std::size_t id, std::size_t n) : device_id(id)
{ {
...@@ -171,7 +165,7 @@ struct hip_device ...@@ -171,7 +165,7 @@ struct hip_device
std::size_t stream_id() const { return current_stream; } std::size_t stream_id() const { return current_stream; }
std::string get_device_name() const { return get_arch_name(device_props); } std::string get_device_name() const { return device_props.gcnArchName; }
std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); } std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); }
......
...@@ -33,8 +33,6 @@ namespace migraphx { ...@@ -33,8 +33,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string get_arch_name(const hipDeviceProp_t& props);
MIGRAPHX_GPU_EXPORT std::string get_device_name(); MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id(); MIGRAPHX_GPU_EXPORT int get_device_id();
......
...@@ -92,7 +92,7 @@ struct hip_sync_stream ...@@ -92,7 +92,7 @@ struct hip_sync_stream
return inputs.front(); return inputs.front();
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(const context& ctx, const shape&, const std::vector<argument>& args) const
{ {
gpu_sync(ctx); gpu_sync(ctx);
if(args.empty()) if(args.empty())
......
...@@ -37,7 +37,7 @@ struct module; ...@@ -37,7 +37,7 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m);
MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx, MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<instruction_ref>& inputs, const std::vector<instruction_ref>& inputs,
const value& solution); const value& solution);
...@@ -47,7 +47,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, ...@@ -47,7 +47,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
code_object_op co, code_object_op co,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs);
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(module m, MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m,
const std::vector<shape>& inputs); const std::vector<shape>& inputs);
} // namespace gpu } // namespace gpu
......
...@@ -300,7 +300,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -300,7 +300,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto rank = a_shape.lens().size(); // cppcheck-suppress unreadVariable
auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
......
...@@ -37,7 +37,7 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -37,7 +37,7 @@ struct mlir_compiler : compiler<mlir_compiler>
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; } operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
compiler_replace compiler_replace
compile(context& ctx, instruction_ref ins, const operation&, const value& solution) const compile(const context& ctx, instruction_ref ins, const operation&, const value& solution) const
{ {
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1); assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
...@@ -52,14 +52,16 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -52,14 +52,16 @@ struct mlir_compiler : compiler<mlir_compiler>
}}; }};
} }
optional<tuning_config> optional<tuning_config> get_tuning_config(const context& ctx,
get_tuning_config(context&, instruction_ref ins, const operation&, bool exhaustive) const instruction_ref ins,
const operation&,
bool exhaustive) const
{ {
if(not exhaustive) if(not exhaustive)
return nullopt; return nullopt;
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(*smod, shapes); return get_tuning_config_mlir(ctx, *smod, shapes);
} }
}; };
......
...@@ -36,7 +36,10 @@ ...@@ -36,7 +36,10 @@
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck
#ifndef CPPCHECK
#undef MIGRAPHX_MLIR #undef MIGRAPHX_MLIR
#endif
#else #else
#include <mlir-c/RegisterRocMLIR.h> #include <mlir-c/RegisterRocMLIR.h>
#endif #endif
...@@ -173,12 +176,6 @@ std::string mlir_print(F f, T x) ...@@ -173,12 +176,6 @@ std::string mlir_print(F f, T x)
return ss.str(); return ss.str();
} }
bool has_xdlops(const std::string& target_arch)
{
const auto device_name = trim(split_string(target_arch, ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
}
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
...@@ -513,7 +510,8 @@ struct mlir_program ...@@ -513,7 +510,8 @@ struct mlir_program
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)}, ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", sym_name}, {"sym_name", sym_name},
{"kernel", std::string("mixr")}, {"kernel", std::string("mixr")},
{"arch", target_arch}}); {"arch", target_arch},
{"num_cu", num_cu}});
ops.add_region(std::move(region)); ops.add_region(std::move(region));
insert(body, std::move(ops)); insert(body, std::move(ops));
...@@ -596,9 +594,6 @@ struct mlir_program ...@@ -596,9 +594,6 @@ struct mlir_program
{ {
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
if(has_xdlops(target_arch))
ops.add_attributes({{"xdlopsV2", true}});
} }
std::vector<MlirValue> inputs; std::vector<MlirValue> inputs;
...@@ -647,7 +642,12 @@ struct mlir_program ...@@ -647,7 +642,12 @@ struct mlir_program
return op; return op;
} }
void find_target() { target_arch = get_device_name(); } void set_gpu_properties(const context& migraphx_ctx)
{
const auto& device = migraphx_ctx.get_current_device();
target_arch = device.get_device_name();
num_cu = device.get_cu_count();
}
std::pair<std::size_t, std::size_t> get_launch_params() const std::pair<std::size_t, std::size_t> get_launch_params() const
{ {
...@@ -661,7 +661,7 @@ struct mlir_program ...@@ -661,7 +661,7 @@ struct mlir_program
value::binary get_binary() const value::binary get_binary() const
{ {
int size = 0; size_t size = 0;
mlirGetBinary(mmodule.get(), &size, nullptr); mlirGetBinary(mmodule.get(), &size, nullptr);
value::binary result(size); value::binary result(size);
if(mlirGetBinary(mmodule.get(), &size, reinterpret_cast<char*>(result.data()))) if(mlirGetBinary(mmodule.get(), &size, reinterpret_cast<char*>(result.data())))
...@@ -669,30 +669,41 @@ struct mlir_program ...@@ -669,30 +669,41 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program"); MIGRAPHX_THROW("Failed to compile mlir program");
} }
void set_tuning(const value& v) void set_tuning(const value& v) MIGRAPHX_TIDY_CONST
{ {
auto str = v.to<std::string>(); const auto* str = v.if_string();
// We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string if(str == nullptr)
std::vector<char> buffer(str.begin(), str.end()); MIGRAPHX_THROW("mlir tuning solutions must be strings");
buffer.push_back(0); if(not mlirRockTuningSetFromStr(mmodule.get(), make_mlir_string_ref(*str)))
if(not mlirRockTuningSetFromStr(mmodule.get(), buffer.data())) MIGRAPHX_THROW("Failed setting tuning key: " + *str);
MIGRAPHX_THROW("Failed setting tuning key: " + str);
} }
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get())}; mlir_tuning_space params{
for(auto i : range(mlirRockTuningGetNumParamsFull(params.get()))) mlirRockTuningSpaceCreate(mmodule.get(), RocmlirTuningParamSetKindFull)};
for(auto i : range(mlirRockTuningGetNumParams(params.get())))
{ {
mlir_tuning_param param{mlirRockTuningParamCreate()}; mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get())) if(not mlirRockTuningParamGet(params.get(), i, param.get()))
MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i)); MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
tc.solutions.push_back(std::string{mlirRockTuningGetParamStr(param.get())}); std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> perf_key;
} size_t perf_key_bytes =
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()}; mlirRockTuningParamToString(param.get(), perf_key.data(), perf_key.size());
tc.problem = std::string{mlirRockTuningGetKey(tuning_table.get(), mmodule.get())}; if(perf_key_bytes > perf_key.size())
MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) +
" bytes and thus too long");
tc.solutions.emplace_back(perf_key.begin(), perf_key.begin() + perf_key_bytes);
}
std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key;
size_t tuning_key_bytes =
mlirRockTuningGetKey(mmodule.get(), tuning_key.data(), tuning_key.size());
if(tuning_key_bytes > tuning_key.size())
MIGRAPHX_THROW("Tuning table key was " + std::to_string(tuning_key_bytes) +
" bytes and thus too long");
tc.problem = std::string(tuning_key.begin(), tuning_key.begin() + tuning_key_bytes);
return tc; return tc;
} }
...@@ -700,10 +711,10 @@ struct mlir_program ...@@ -700,10 +711,10 @@ struct mlir_program
// This function appends to tuning cfg file that could be // This function appends to tuning cfg file that could be
// used with rocMLIR tuning scripts. // used with rocMLIR tuning scripts.
void dump_tuning_cfg(const char* prob_config) const void dump_tuning_cfg(const std::string& prob_config) const
{ {
std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{}); std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{});
if(!tuning_cfg_path.empty()) if(not tuning_cfg_path.empty())
{ {
std::vector<std::string> tokens = split_string(prob_config, '\t'); std::vector<std::string> tokens = split_string(prob_config, '\t');
std::string prob = tokens[1]; std::string prob = tokens[1];
...@@ -720,51 +731,66 @@ struct mlir_program ...@@ -720,51 +731,66 @@ struct mlir_program
} }
} }
static mlir_tuning_table create_tuning_table() static std::pair<mlir_tuning_table, bool> load_tuning_table()
{ {
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()}; mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
bool found_table = false;
std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{}); std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{});
if(!tuning_db_path.empty()) if(not tuning_db_path.empty())
{ {
std::ifstream tuning_db_tsv(tuning_db_path); std::ifstream tuning_db_tsv(tuning_db_path);
if(tuning_db_tsv) if(tuning_db_tsv)
{ {
found_table = true;
std::string line; std::string line;
while(std::getline(tuning_db_tsv, line)) while(std::getline(tuning_db_tsv, line))
{ {
std::vector<std::string> tokens = split_string(line, '\t'); std::vector<std::string> tokens = split_string(line, '\t');
std::string arch = tokens[0]; std::string arch = tokens[0];
std::string prob = tokens[1]; std::string num_cu = tokens[1];
std::string perf = tokens[2]; std::string prob = tokens[2];
std::string key = arch.append("\t").append(prob); std::string perf = tokens[3];
mlirRockTuningUpdateTable(tuning_table.get(), key.c_str(), perf.c_str(), 1.0); std::string key = arch.append("\t").append(num_cu).append("\t").append(prob);
mlirRockTuningUpdateTable(tuning_table.get(),
make_mlir_string_ref(key),
make_mlir_string_ref(perf),
1.0);
} }
} }
} }
else else
{ {
found_table = false;
std::cerr std::cerr
<< "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for " << "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
"optimal performance." "optimal performance."
<< std::endl; << std::endl;
} }
return tuning_table; return std::make_pair(std::move(tuning_table), found_table);
} }
bool get_module_tuned() const bool get_module_tuned() const
{ {
static mlir_tuning_table tuning_table = create_tuning_table(); static std::pair<mlir_tuning_table, bool> tuning_table = load_tuning_table();
// The tuning table as currently implemented is currently not if(not mlirRockTuningSetFromTable(tuning_table.first.get(), mmodule.get()))
// thread safe. This will be fixed in the future. For now, {
// stick a mutex around all tuning table interaction. std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> prob_config;
static std::mutex lock; size_t prob_config_bytes =
std::lock_guard<std::mutex> guard(lock); mlirRockTuningGetKey(mmodule.get(), prob_config.data(), prob_config.size());
if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get())) if(prob_config_bytes >= prob_config.size())
{ {
const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get()); std::cerr << "MLIR tuning key overflowed buffer, needed " << prob_config_bytes
std::stringstream key(prob_config); << " bytes" << std::endl;
std::cerr << "fails to set param on" << prob_config << std::endl; return false;
dump_tuning_cfg(prob_config); }
std::string prob_config_str(prob_config.begin(),
prob_config.begin() + prob_config_bytes);
if(tuning_table.second)
{
std::cerr << "NOTE: MLIR tuning table did not include a key for " << prob_config_str
<< std::endl;
}
dump_tuning_cfg(prob_config_str);
return false; return false;
} }
return true; return true;
...@@ -775,7 +801,8 @@ struct mlir_program ...@@ -775,7 +801,8 @@ struct mlir_program
mlir_module mmodule; mlir_module mmodule;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
std::string target_arch; std::string target_arch = "";
std::size_t num_cu = 0;
std::string sym_name; std::string sym_name;
}; };
...@@ -832,7 +859,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs) ...@@ -832,7 +859,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
} }
} }
code_object_op compile_mlir(const context&, code_object_op compile_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<instruction_ref>& inputs, const std::vector<instruction_ref>& inputs,
const value& solution) const value& solution)
...@@ -844,7 +871,7 @@ code_object_op compile_mlir(const context&, ...@@ -844,7 +871,7 @@ code_object_op compile_mlir(const context&,
std::cout << m << std::endl; std::cout << m << std::endl;
mlir_program mp; mlir_program mp;
mp.find_target(); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
...@@ -871,12 +898,13 @@ instruction_ref insert_mlir(module& m, ...@@ -871,12 +898,13 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs) tuning_config
get_tuning_config_mlir(const context& migraphx_ctx, module m, const std::vector<shape>& inputs)
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, inputs);
mlir_program mp; mlir_program mp;
mp.find_target(); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
return mp.get_tuning_config(); return mp.get_tuning_config();
} }
...@@ -903,10 +931,14 @@ instruction_ref ...@@ -903,10 +931,14 @@ instruction_ref
insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&) insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
{ {
use(co); use(co);
use(m);
return m.end(); return m.end();
} }
tuning_config get_tuning_config_mlir(module, const std::vector<shape>&) { return {}; } tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&)
{
return {};
}
// NOLINTEND(performance-unnecessary-value-param) // NOLINTEND(performance-unnecessary-value-param)
#endif #endif
......
...@@ -34,7 +34,7 @@ namespace gpu { ...@@ -34,7 +34,7 @@ namespace gpu {
std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0) std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
{ {
std::vector<argument> args; std::vector<argument> args;
std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](auto& s) { std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](const auto& s) {
return to_gpu(generate_argument(s, seed++)); return to_gpu(generate_argument(s, seed++));
}); });
return args; return args;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment