Commit 27c8f6ef authored by Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir_gconv_msg

parents dbffeaff 6f1c947f
...@@ -34,16 +34,65 @@ namespace onnx { ...@@ -34,16 +34,65 @@ namespace onnx {
struct parse_slice : op_parser<parse_slice> struct parse_slice : op_parser<parse_slice>
{ {
std::vector<op_desc> operators() const { return {{"Slice"}}; } std::vector<op_desc> operators() const { return {{"Slice"}}; }
struct slice_desc
{
op::slice op;
std::vector<instruction_ref> op_args;
std::vector<int64_t> steps;
std::vector<int64_t> raxes;
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
std::vector<int64_t> insert(instruction_ref arg)
{
std::vector<int64_t> result;
migraphx::argument arg_value = arg->eval();
if(arg_value.empty())
{
op_args.insert(op_args.begin(), arg);
}
else
{
arg_value.visit([&](auto s) { result.assign(s.begin(), s.end()); });
}
return result;
}
};
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const const std::vector<instruction_ref>& args) const
{ {
op::slice op; auto sd = construct_slice_desc(parser, info, args);
auto ins = info.add_instruction(sd.op, sd.op_args);
if(not sd.raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", sd.raxes}}), ins);
}
// If any steps are other than default 1, add a "steps" op
if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
std::transform(sd.steps.begin(),
sd.steps.end(),
std::back_inserter(nsteps),
[](auto s) { return std::abs(s); });
return ins = info.add_instruction(
make_op("step", {{"axes", sd.op.axes}, {"steps", nsteps}}), ins);
}
else
return ins;
}
std::vector<int64_t> steps; slice_desc construct_slice_desc(const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
slice_desc sd;
// slice can have up to 5 inputs, we first check the 5th one // slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice. // to decide whether MIGRAPHX can handle this slice.
...@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice> ...@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
{ {
migraphx::argument step_arg = args.back()->eval(); migraphx::argument step_arg = args.back()->eval();
check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice"); check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); }); step_arg.visit([&](auto s) { sd.steps.assign(s.begin(), s.end()); });
} }
if(args.size() >= 4) if(args.size() >= 4)
{ {
migraphx::argument axes_arg = args.at(3)->eval(); sd.op.axes = sd.insert(args.at(3));
check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
} }
else if(contains(info.attributes, "axes")) else if(contains(info.attributes, "axes"))
{ {
literal s = parser.parse_value(info.attributes.at("axes")); literal s = parser.parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.axes)); });
} }
if(args.size() >= 3) if(args.size() >= 3)
{ {
migraphx::argument end_arg = args.at(2)->eval(); sd.op.ends = sd.insert(args.at(2));
check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
} }
else if(contains(info.attributes, "ends")) else if(contains(info.attributes, "ends"))
{ {
literal s = parser.parse_value(info.attributes.at("ends")); literal s = parser.parse_value(info.attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); }); s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.ends)); });
} }
if(args.size() >= 2) if(args.size() >= 2)
{ {
migraphx::argument start_arg = args.at(1)->eval(); sd.op.starts = sd.insert(args.at(1));
check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
} }
else if(contains(info.attributes, "starts")) else if(contains(info.attributes, "starts"))
{ {
literal s = parser.parse_value(info.attributes.at("starts")); literal s = parser.parse_value(info.attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.starts)); });
} }
// data input argument
sd.always_insert(args.at(0));
// If axes arg is not given, the default is all of them. // If axes arg is not given, the default is all of them.
if(op.axes.empty()) if(sd.op.axes.empty() and sd.op_args.size() < 3)
{ {
std::vector<int64_t> axes(args[0]->get_shape().ndim()); std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0}); std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes; sd.op.axes = axes;
} }
std::vector<int64_t> raxes; if(not sd.steps.empty())
{
if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported");
if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
}
assert(steps.empty() or steps.size() == op.axes.size()); assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size());
// If any axes have negative step, prepare to add a "reverse" op // If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size())) for(auto i : range(sd.steps.size()))
{ {
if(steps[i] >= 0) if(sd.steps[i] >= 0)
continue; continue;
op.starts[i] += 1; sd.op.starts[i] += 1;
if(op.starts[i] == 0) if(sd.op.starts[i] == 0)
op.starts[i] = INT_MAX; sd.op.starts[i] = INT_MAX;
op.ends[i] += 1; sd.op.ends[i] += 1;
raxes.push_back(op.axes[i]); sd.raxes.push_back(sd.op.axes[i]);
std::swap(op.starts[i], op.ends[i]); std::swap(sd.op.starts[i], sd.op.ends[i]);
}
auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
} }
// If any steps are other than default 1, add a "steps" op return sd;
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
return std::abs(s);
});
return ins = info.add_instruction(
make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
}
else
return ins;
} }
}; };
......
...@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op ...@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
// Gather all the target roots // Gather all the target roots
std::unordered_multimap<std::size_t, module_ref> roots; std::unordered_multimap<std::size_t, module_ref> roots;
auto mods = this->get_modules(); auto mods = this->get_modules();
for(auto* mod : mods) for(const auto* mod : mods)
{ {
for(const auto& ins : *mod) for(const auto& ins : *mod)
{ {
...@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out[x] = ss.str(); ins_out[x] = ss.str();
}); });
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) { ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
auto& ctx = contexts[ins->get_target_id()]; const auto& ctx = contexts[ins->get_target_id()];
ctx.finish(); ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{}; timer t{};
...@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod, ...@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
std::back_inserter(module_inputs), std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); }); [&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs) for(const auto& smod : module_inputs)
{ {
mod_from_val(smod, v, instructions, map_mods); mod_from_val(smod, v, instructions, map_mods);
} }
...@@ -1186,7 +1186,7 @@ void program::remove_unused_modules() ...@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
std::vector<module*> unused; std::vector<module*> unused;
generic_get_unused_modules( generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused)); impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused) for(const auto* m : unused)
this->remove_module(m->name()); this->remove_module(m->name());
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp) add_library(migraphx_py py_loader.cpp)
migraphx_generate_export_header(migraphx_py)
target_include_directories(migraphx_py PRIVATE include) target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx) target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include) rocm_install_targets(TARGETS migraphx_py INCLUDE include)
......
...@@ -26,11 +26,12 @@ ...@@ -26,11 +26,12 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program load_py(const std::string& filename); MIGRAPHX_PY_EXPORT program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -64,7 +64,7 @@ static dynamic_loader py_lib() ...@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
return lib; return lib;
} }
program load_py(const std::string& filename) MIGRAPHX_PY_EXPORT program load_py(const std::string& filename)
{ {
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py"); static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename); return f(filename);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant); auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
std::vector<int> max_data(s.elements(), max_quant); auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto min_arg = m.add_literal(literal(s, min_data)); auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
} }
......
...@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) ...@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
}; };
}; };
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2); return (dots >= 2 or convs >= 2 or qdots >= 2);
} }
struct find_conv_dot_horiz_fusion struct find_conv_dot_horiz_fusion
...@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) { auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator()) if(i->get_operator() != j->get_operator())
return false; return false;
if(not contains({"dot", "convolution"}, i->name())) if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true; return true;
auto x = i->inputs()[1]->get_shape().lens(); auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens(); auto y = j->inputs()[1]->get_shape().lens();
...@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return false; return false;
// Check that non-axes match // Check that non-axes match
int axis = 1; int axis = 1;
if(i->name() == "dot") if(i->name() == "dot" or i->name() == "quant_dot")
{ {
axis = x.size() - 1; axis = x.size() - 1;
} }
...@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2) if(std::distance(start, last) < 2)
return; return;
auto&& name = (*start)->name(); auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name)) if(not contains({"quant_dot", "dot", "convolution"}, name))
return; return;
auto op = (*start)->get_operator(); auto op = (*start)->get_operator();
int group = 1; int group = 1;
...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); }); start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1; int axis = 1;
int concat_axis = 0; int concat_axis = 0;
if(name == "dot") if(name == "dot" or name == "quant_dot")
{ {
axis = int(args.front()->get_shape().lens().size() - 1); axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis; concat_axis = axis;
......
...@@ -48,6 +48,7 @@ struct sqlite_impl ...@@ -48,6 +48,7 @@ struct sqlite_impl
template <class F> template <class F>
void exec(const char* sql, F f) void exec(const char* sql, F f)
{ {
// cppcheck-suppress constParameterPointer
auto callback = [](void* obj, auto... xs) -> int { auto callback = [](void* obj, auto... xs) -> int {
try try
{ {
......
...@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot> ...@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)}; MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
} }
void required(const check_shapes& cs) const { cs.not_broadcasted(); } template <class T>
void required(const check_shapes<T>& cs) const
{
cs.not_broadcasted();
}
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
......
...@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive> ...@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
} }
// dnnl has some issues with non-packed inputs // dnnl has some issues with non-packed inputs
void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); } template <class T>
void required(const check_shapes<T>& cs) const
{
cs.packed_or_broadcasted();
}
std::string name() const { return "dnnl::" + op.name(); } std::string name() const { return "dnnl::" + op.name(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -61,7 +61,7 @@ namespace cpu { ...@@ -61,7 +61,7 @@ namespace cpu {
std::string target::name() const { return "cpu"; } std::string target::name() const { return "cpu"; }
// cppcheck-suppress constParameter // cppcheck-suppress constParameterReference
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const
{ {
auto& ctx = any_cast<context>(gctx); auto& ctx = any_cast<context>(gctx);
......
...@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m)
{ {
if(starts_with(ins.name(), "@")) if(starts_with(ins.name(), "@"))
continue; continue;
if(ins.name() == "multibroadcast") if(contains({"multibroadcast", "contiguous"}, ins.name()))
continue; continue;
if(ins.name() == "pointwise") if(ins.name() == "pointwise")
{ {
......
...@@ -41,7 +41,7 @@ struct index ...@@ -41,7 +41,7 @@ struct index
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT __device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT __device__ index_int nlocal() const { return blockDim.x; } // NOLINT
template <class F> template <class F>
__device__ void global_stride(index_int n, F f) const __device__ void global_stride(index_int n, F f) const
...@@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
dim3 nthreads(local); dim3 nthreads(local);
// cppcheck-suppress UseDeviceLaunch // cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f); hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
hipError_t kernel_launch_status = hipGetLastError();
if(kernel_launch_status != hipSuccess)
{
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " +
std::string(hipGetErrorString(kernel_launch_status)));
}
}; };
} }
......
...@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl( ...@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl(
buffer[i] = binput.data()[i]; buffer[i] = binput.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); const auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
...@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl( ...@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl(
buffer[i + bdim_vec_len] = binput2.data()[i]; buffer[i + bdim_vec_len] = binput2.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); const auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
......
...@@ -72,12 +72,12 @@ struct hip_heap_vector ...@@ -72,12 +72,12 @@ struct hip_heap_vector
index_int l = 2 * index + 1; index_int l = 2 * index + 1;
index_int r = 2 * index + 2; index_int r = 2 * index + 2;
if(l < n && compare(data[data_index(l)], data[data_index(index)])) if(l < n and compare(data[data_index(l)], data[data_index(index)]))
{ {
index = l; index = l;
} }
if(r < n && compare(data[data_index(r)], data[data_index(index)])) if(r < n and compare(data[data_index(r)], data[data_index(index)]))
{ {
index = r; index = r;
if(compare(data[data_index(l)], data[data_index(r)])) if(compare(data[data_index(l)], data[data_index(r)]))
......
...@@ -31,20 +31,6 @@ namespace migraphx { ...@@ -31,20 +31,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
int get_device_id() int get_device_id()
{ {
int device; int device;
...@@ -60,7 +46,7 @@ std::string get_device_name() ...@@ -60,7 +46,7 @@ std::string get_device_name()
auto status = hipGetDeviceProperties(&props, get_device_id()); auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties"); MIGRAPHX_THROW("Failed to get device properties");
return get_arch_name(props); return props.gcnArchName;
} }
} // namespace gpu } // namespace gpu
......
...@@ -86,7 +86,7 @@ struct mlir_op ...@@ -86,7 +86,7 @@ struct mlir_op
size_t param_cnt = 0; size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names(); std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
for(std::string param_name : names) for(const std::string& param_name : names)
{ {
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++]; ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
} }
...@@ -214,42 +214,47 @@ struct find_mlir_op ...@@ -214,42 +214,47 @@ struct find_mlir_op
return false; return false;
} }
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"}; const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {"convolution", const std::initializer_list<std::string> no_bool_ops = {
"quant_convolution", "convolution",
"dot", "quant_convolution",
"quant_dot", "dot",
"add", "quant_dot",
"clip", "add",
"relu", "clip",
"sub", "relu",
"mul", "sub",
"div", "mul",
"pow", "div",
"where", "pow",
"quantizelinear", "where",
"dequantizelinear", "quantizelinear",
"abs", "dequantizelinear",
"neg"}; "abs",
const std::initializer_list<std::string> fp_only_ops = {"ceil", "neg",
"erf", };
"exp", const std::initializer_list<std::string> fp_only_ops = {
"floor", "ceil",
"log", "erf",
"recip", "exp",
"rsqrt", "floor",
"sigmoid" "log",
"softmax", "recip",
"tanh"}; "rsqrt",
// There are bugs in MLIR right now for models using sigmoid so disable it for now
// "sigmoid",
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type && contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true; return true;
if(is_float && contains(fp_only_ops, name)) if(is_float and contains(fp_only_ops, name))
return true; return true;
// Only conversions between floating types are known to be unambiguously // Only conversions between floating types are known to be unambiguously
// supported. // supported.
if(is_float && name == "convert") if(is_float and name == "convert")
{ {
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
......
...@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); ...@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct hip_device struct hip_device
{ {
hip_device() hip_device() : device_props{} { add_stream(); }
{
device_props.gcnArchName[0] = '\0';
device_props.gcnArch = 0;
device_props.multiProcessorCount = 0;
add_stream();
}
hip_device(std::size_t id, std::size_t n) : device_id(id) hip_device(std::size_t id, std::size_t n) : device_id(id)
{ {
...@@ -171,7 +165,7 @@ struct hip_device ...@@ -171,7 +165,7 @@ struct hip_device
std::size_t stream_id() const { return current_stream; } std::size_t stream_id() const { return current_stream; }
std::string get_device_name() const { return get_arch_name(device_props); } std::string get_device_name() const { return device_props.gcnArchName; }
std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); } std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); }
......
...@@ -33,8 +33,6 @@ namespace migraphx { ...@@ -33,8 +33,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string get_arch_name(const hipDeviceProp_t& props);
MIGRAPHX_GPU_EXPORT std::string get_device_name(); MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id(); MIGRAPHX_GPU_EXPORT int get_device_id();
......
...@@ -92,7 +92,7 @@ struct hip_sync_stream ...@@ -92,7 +92,7 @@ struct hip_sync_stream
return inputs.front(); return inputs.front();
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(const context& ctx, const shape&, const std::vector<argument>& args) const
{ {
gpu_sync(ctx); gpu_sync(ctx);
if(args.empty()) if(args.empty())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment