Commit 64998903 authored by Chris Austen

Merge remote-tracking branch 'origin/develop' into rel57_workitems

parents 2190957c e4dc75ea
@@ -460,11 +460,11 @@ instruction_ref module::add_parameter(std::string name, shape s)
 instruction_ref module::add_return(std::vector<instruction_ref> args)
 {
-    impl->push_back({builtin::returns{}, {}, std::move(args)});
+    shape instr_shape = compute_shape(builtin::returns{}, args);
+    impl->push_back({builtin::returns{}, instr_shape, std::move(args)});
     auto result = std::prev(impl->instructions.end());
     instruction::backreference(result);
     assert(result->valid(begin()));
     return result;
 }
...
@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
         auto x_rank = x_lens.size();
         if(x_rank == 1 or x_rank == 2)
         {
-            auto rt      = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
-            auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
-            auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
-            auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
-            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
-            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
-            auto r0      = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            auto eps        = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto var_eps    = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto rsqrt      = info.add_instruction(make_op("rsqrt"), var_eps);
+            auto mul0       = info.add_broadcastable_binary_op("mul", args[1], rsqrt);
+            auto r0         = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
             return info.add_broadcastable_binary_op("add", r0, args[2]);
         }
         else if(x_rank > 2)
@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
             // unsqueeze tensors of shape (C) to broadcast correctly
             std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
             std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
-            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
             auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
             auto scale_unsqueeze = info.add_instruction(
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
             auto var_unsqueeze = info.add_instruction(
                 migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
-            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
             auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
-            auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
-            auto div0  = info.add_broadcastable_binary_op("div", numer, denom);
-            auto r0    = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
+            auto mul0  = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
+            auto r0    = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
             return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
         }
         else
...
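Note: both hunks make the same algebraic substitution. The old code computed scale * (x - mean) / pow(var + eps, 0.5) + bias; the new code computes (x - mean) * (scale * rsqrt(var + eps)) + bias, trading a generic pow and a division for an rsqrt and two multiplies. A minimal standalone check of the equivalence (plain C++, not MIGraphX code):

#include <cassert>
#include <cmath>

int main()
{
    // one scalar element: input x, mean m, variance v, epsilon e, scale s, bias b
    double x = 2.0, m = 0.5, v = 4.0, e = 1e-5, s = 1.5, b = 0.25;
    double before = (x - m) / std::pow(v + e, 0.5) * s + b;       // pow + div form
    double after  = (x - m) * (s * (1.0 / std::sqrt(v + e))) + b; // rsqrt + mul form
    assert(std::abs(before - after) < 1e-12);
    return 0;
}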
@@ -79,13 +79,11 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto x     = args[0];
         auto scale = args[1];
         auto bias  = args[2];
-        auto dims  = x->get_shape().lens();
         if(not contains(valid_types, dtype))
             MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
                            ". Valid types are 1 (float), 10 (half), and 11 (double).");
-        bool dyn_input = x->get_shape().dynamic();
         auto ndims = x->get_shape().ndim();
         assert(ndims >= 2);
         auto kdims = ndims - 2;
         std::vector<int64_t> axes(kdims);
@@ -102,6 +100,12 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
             (dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
         if(dtype == shape::half_type and not convert_fp16)
         {
+            if(x->get_shape().dynamic())
+            {
+                MIGRAPHX_THROW("PARSE_INSTANCENORM: half type not supported with dynamic shape "
+                               "unless convert_fp16 is TRUE");
+            }
+            auto dims = x->get_shape().lens();
             double n =
                 std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
                     return i * j;
@@ -122,13 +126,14 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         // both scale and bias.
         instruction_ref scale_bcast;
         instruction_ref bias_bcast;
-        if(dyn_input)
+        if(x->get_shape().dynamic())
         {
             scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
             bias_bcast  = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
         }
         else
         {
+            auto dims   = x->get_shape().lens();
             scale_bcast = info.add_instruction(
                 make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
             bias_bcast =
...
@@ -30,8 +30,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {
-// Use a literal instruction to replace the shape since, output of
-// shape operator are literals in migraphx
+/**
+ * If static shape input, creates a literal in migraphx.
+ * If dynamic shape input, creates a dimensions_of operator in migraphx (runtime evaluation of
+ * shape).
+ */
 struct parse_shape : op_parser<parse_shape>
 {
     std::vector<op_desc> operators() const { return {{"Shape"}}; }
@@ -43,13 +46,54 @@ struct parse_shape : op_parser<parse_shape>
     {
         if(args.size() != 1)
             MIGRAPHX_THROW("Shape: operator should have 1 operand");
-        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
-        std::vector<int64_t> vec_shape(arg_shape.size());
-        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
-        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
-            return int64_t(i);
-        });
-        return info.add_literal(migraphx::literal{s, vec_shape});
+        auto input_shape  = args[0]->get_shape();
+        int input_ndim    = input_shape.ndim();
+        std::size_t start = 0;
+        std::size_t end   = input_ndim;
+        // Normalizing the start and end is handled here because of how the static shape version
+        // works. Clamping to [-r, r], where r is ndim of input and then making positive.
+        auto normalize_ind = [&](int64_t ind) {
+            if(ind < (-1 * input_ndim))
+            {
+                ind = -1 * input_ndim;
+            }
+            if(ind > input_ndim)
+            {
+                ind = input_ndim;
+            }
+            return (ind >= 0) ? ind : input_ndim + ind;
+        };
+        if(contains(info.attributes, "end"))
+        {
+            end = normalize_ind(info.attributes.at("end").i());
+        }
+        if(contains(info.attributes, "start"))
+        {
+            start = normalize_ind(info.attributes.at("start").i());
+        }
+        if(end <= start)
+        {
+            MIGRAPHX_THROW("PARSE_SHAPE: ending axis <= starting axis, end: " +
+                           std::to_string(end) + " start: " + std::to_string(start));
+        }
+        if(input_shape.dynamic())
+        {
+            return info.add_instruction(make_op("dimensions_of", {{"start", start}, {"end", end}}),
+                                        args[0]);
+        }
+        else
+        {
+            std::size_t output_ndim = end - start;
+            std::vector<int64_t> vec_shape(output_ndim);
+            migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
+            std::vector<std::size_t> input_lens = input_shape.lens();
+            std::transform(input_lens.begin() + start,
+                           input_lens.begin() + end,
+                           vec_shape.begin(),
+                           [](auto i) { return int64_t(i); });
+            return info.add_literal(migraphx::literal{s, vec_shape});
+        }
     }
 };
...
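Note: a standalone sketch of the start/end normalization above (it mirrors the lambda in the hunk; the driver values below are illustrative):

#include <cstdint>
#include <iostream>

int64_t normalize_ind(int64_t ind, int64_t input_ndim)
{
    // clamp to [-ndim, ndim], then wrap negative indices around
    if(ind < -input_ndim)
        ind = -input_ndim;
    if(ind > input_ndim)
        ind = input_ndim;
    return (ind >= 0) ? ind : input_ndim + ind;
}

int main()
{
    // For an input with lens {2, 3, 4, 5} (ndim = 4):
    std::cout << normalize_ind(-2, 4) << '\n'; // 2 -> Shape(start=-2) yields the literal {4, 5}
    std::cout << normalize_ind(9, 4) << '\n';  // 4 -> out-of-range end is clamped to ndim
    return 0;
}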
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
     return it->first;
 }
+std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
+{
+    auto result = shapes;
+    auto perm   = find_permutation(shapes);
+    std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
+        return reorder_shape(s, perm);
+    });
+    return result;
+}
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
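Note: the new helper reorders every shape in a list by the permutation common to all of them. A hedged usage sketch, assuming the matching declaration lands in migraphx/permutation.hpp (the NHWC-style strides below are illustrative):

#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp>

int main()
{
    // two NCHW tensors carried with transposed (NHWC-style) strides
    migraphx::shape a{migraphx::shape::float_type, {2, 3, 4, 5}, {60, 1, 15, 3}};
    migraphx::shape b{migraphx::shape::float_type, {2, 3, 4, 5}, {60, 1, 15, 3}};
    // both shapes come back reordered to the shared memory layout
    auto normalized = migraphx::normalize_permutation({a, b});
    return 0;
}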
@@ -23,14 +23,24 @@
 #####################################################################################
 option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
+add_library(migraphx_py py_loader.cpp)
+target_include_directories(migraphx_py PRIVATE include)
+target_link_libraries(migraphx_py PUBLIC migraphx)
+rocm_install_targets(TARGETS migraphx_py INCLUDE include)
 if(MIGRAPHX_ENABLE_PYTHON)
     include(PythonModules)
-    add_custom_target(migraphx_py)
     foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
-        py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
-        target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
+        py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
+        target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
+        rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION})
+        add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION})
+        add_library(migraphx_py_${PYTHON_VERSION} py.cpp)
+        target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include)
+        target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx)
+        target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime)
         rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION})
         add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION})
     endforeach()
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/file_buffer.hpp>
#include <pybind11/embed.h>
namespace py = pybind11;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
// extern "C" is used to disable name mangling, but the function will still be called from C++
extern "C" program migraphx_load_py(const std::string& filename);
#ifdef __clang__
#pragma clang diagnostic pop
#endif
const std::string& python_path()
{
static const auto path = dynamic_loader::path(&migraphx_load_py).parent_path().string();
return path;
}
static py::dict run_file(const std::string& file)
{
py::object scope = py::module_::import("__main__").attr("__dict__");
std::string buffer;
buffer.append("import sys\n");
buffer.append("sys.path.insert(0, '" + python_path() + "')\n");
buffer.append("import migraphx\n");
buffer.append(read_string(file));
py::exec(buffer, scope);
return scope.cast<py::dict>();
}
extern "C" program migraphx_load_py(const std::string& filename)
{
py::scoped_interpreter guard{};
py::dict vars = run_file(filename);
auto it = std::find_if(vars.begin(), vars.end(), [](const auto& p) {
return py::isinstance<migraphx::program>(p.second);
});
if(it == vars.end())
MIGRAPHX_THROW("No program variable found");
return it->second.cast<migraphx::program>();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/py.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static std::vector<fs::path> find_available_python_versions()
{
std::vector<fs::path> result;
auto path = dynamic_loader::path(&load_py).parent_path();
for(const auto& entry : fs::directory_iterator{path})
{
if(not entry.is_regular_file())
continue;
auto p = entry.path();
if(not contains(p.stem().string(), "migraphx_py_"))
continue;
result.push_back(p);
}
std::sort(result.begin(), result.end(), std::greater<>{});
return result;
}
static dynamic_loader load_py_lib()
{
auto libs = find_available_python_versions();
for(const auto& lib : libs)
{
auto result = dynamic_loader::try_load(lib);
if(result.has_value())
return *result;
}
MIGRAPHX_THROW("Cant find a viable version of python");
}
static dynamic_loader py_lib()
{
static dynamic_loader lib = load_py_lib();
return lib;
}
program load_py(const std::string& filename)
{
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
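Note: taken together, py.hpp declares the stable entry point, each migraphx_py_<version> library embeds the interpreter behind the C-linkage migraphx_load_py, and py_loader.cpp dlopens whichever version loads successfully. A hypothetical caller (the file name is illustrative; the script must assign a migraphx.program to some variable):

#include <migraphx/py.hpp>
#include <iostream>

int main()
{
    // load_py runs the script in an embedded interpreter and returns the
    // first variable found that holds a migraphx.program
    migraphx::program p = migraphx::load_py("model.py");
    std::cout << p << std::endl;
    return 0;
}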
@@ -29,6 +29,7 @@
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/optimize_module.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
@@ -48,19 +49,12 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
 // This function is to convert any instructions specified in the input
 // from double or float to float16 by inserting a convert operator.
-// For the conversion, there could be cases of overflowing, but it
-// is very rare in the area of deeping learning, so we just do a
-// truncate of the input to get the fp16.
+// For the conversion, there could be cases of overflowing or underflowing, but it
+// is uncommon. Run optimize_module() before converting to fp16 to const eval and fold in FP32 to
+// avoid loss of precision.
 void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
 {
-    run_passes(prog,
-               {quantize_fp16_pass{ins_names},
-                eliminate_common_subexpression{},
-                dead_code_elimination{},
-                simplify_reshapes{},
-                dead_code_elimination{},
-                simplify_qdq{},
-                dead_code_elimination{}});
+    run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
 }
 void quantize_int8(program& prog,
...
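Note: the pass-list change is invisible to callers. A sketch of typical use, assuming the public parse_onnx/quantize_fp16 API (the file name is illustrative, and {"all"} marks every eligible instruction for conversion):

#include <migraphx/onnx.hpp>
#include <migraphx/quantization.hpp>

int main()
{
    migraphx::program prog = migraphx::parse_onnx("model.onnx");
    // optimize_module{} now runs both before and after the fp16 conversion pass,
    // so constant folding happens in FP32 before any precision is lost
    migraphx::quantize_fp16(prog, {"all"});
    return 0;
}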
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -273,9 +273,23 @@ shape shape::from_permutation(type_t t,
 shape::type_t shape::type() const { return impl->m_type; }
-const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
+const std::vector<std::size_t>& shape::lens() const
+{
+    if(this->dynamic())
+    {
+        MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
+    }
+    return impl->m_lens;
+}
-const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
+const std::vector<std::size_t>& shape::strides() const
+{
+    if(this->dynamic())
+    {
+        MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
+    }
+    return impl->m_strides;
+}
 std::size_t shape::ndim() const
 {
@@ -535,7 +549,14 @@ bool shape::any_of_dynamic() const
     });
 }
-const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
+const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
+{
+    if(not this->dynamic())
+    {
+        MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
+    }
+    return impl->m_dyn_dims;
+}
 std::vector<std::size_t> shape::min_lens() const
 {
@@ -679,12 +700,22 @@ const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
 void migraphx_to_value(value& v, const shape& s)
 {
     value result;
     result["type"] = migraphx::to_value(s.type_string());
-    result["lens"]               = migraphx::to_value(s.lens());
-    result["strides"]            = migraphx::to_value(s.strides());
-    result["sub_shapes"]         = migraphx::to_value(s.sub_shapes());
-    result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
-    v                            = result;
+    result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
+    // avoid calling functions that will throw
+    if(s.dynamic())
+    {
+        result["lens"]               = {};
+        result["strides"]            = {};
+        result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
+    }
+    else
+    {
+        result["lens"]               = migraphx::to_value(s.lens());
+        result["strides"]            = migraphx::to_value(s.strides());
+        result["dynamic_dimensions"] = {};
+    }
+    v = result;
 }
 void migraphx_from_value(const value& v, shape& s)
...
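Note: under the new contract, lens() and strides() throw for dynamic shapes and dyn_dims() throws for static ones, so callers must branch on dynamic() first, exactly as migraphx_to_value does above. A small sketch of the pattern:

#include <migraphx/shape.hpp>
#include <iostream>

void print_dims(const migraphx::shape& s)
{
    if(s.dynamic())
    {
        for(const auto& dd : s.dyn_dims()) // safe: shape is dynamic
            std::cout << dd.min << ".." << dd.max << ' ';
    }
    else
    {
        for(std::size_t len : s.lens()) // safe: shape is static
            std::cout << len << ' ';
    }
    std::cout << '\n';
}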
@@ -89,38 +89,23 @@ struct find_reshaper
 {
     auto matcher() const
     {
-        return match::name(reshaper_names())(
-            match::any_of[match::outputs()](match::name(reshaper_names())));
+        auto reshaper          = match::name(reshaper_names());
+        auto contiguous        = match::name("contiguous");
+        auto no_output_reshape = match::none_of[match::outputs()](reshaper);
+        auto input_reshape     = match::arg(0)(match::skip(contiguous)(reshaper));
+        auto input             = match::skip(reshaper, contiguous)(match::any().bind("x"));
+        return reshaper(no_output_reshape, input_reshape, input);
     }
     void apply(module& m, const match::matcher_result& mr) const
     {
         auto ins = mr.result;
-        std::vector<instruction_ref> reshapes{ins};
-        while(is_reshaper(reshapes.back()))
-        {
-            assert(not reshapes.back()->inputs().empty());
-            assert(m.has_instruction(reshapes.back()->inputs().front()));
-            auto input = reshapes.back()->inputs().front();
-            reshapes.push_back(input);
-        }
-        std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
-        for(auto start : iterator_for(reshapes))
-        {
-            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
-                return i->get_shape() == (*start)->get_shape() and i != (*start);
-            });
-            if(last != reshapes.rend())
-            {
-                r = std::make_pair(*start, *last);
-                break;
-            }
-        }
-        if(r.first != r.second)
-        {
-            m.replace_instruction(r.first, r.second);
-        }
+        auto input = mr.instructions["x"];
+        auto dims  = ins->get_shape().lens();
+        if(not input->get_shape().standard())
+            input = m.insert_instruction(ins, make_op("contiguous"), input);
+        m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
     }
 };
@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const
     match::find_matches(m,
                         find_where_op{},
                         find_resize{},
-                        find_reshape_cont{},
                         find_nop_reshapes{},
                         find_reshaper{},
+                        find_reshape_cont{},
                         find_transpose{},
                         find_concat_transpose{},
                         find_concat_multibroadcasts{},
...
@@ -33,7 +33,10 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()
-find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)
+if(NOT WIN32)
+    # TODO: re-enable when CK is ported to Windows
+    find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
+endif()
 if(BUILD_DEV)
     set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
@@ -45,7 +48,7 @@ include(Embed)
 file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
     ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
 message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
-add_embed_library(migraphx_kernels ${KERNEL_FILES})
+add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
 file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
 add_library(migraphx_device ${DEVICE_GPU_SRCS})
@@ -85,6 +88,12 @@ target_link_libraries(kernel_file_check compile_for_gpu)
 rocm_clang_tidy_check(kernel_file_check)
 file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
+if(WIN32)
+    # TODO: re-enable when CK is ported to Windows
+    list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
+endif()
 add_library(migraphx_gpu
     abs.cpp
     analyze_streams.cpp
@@ -133,6 +142,7 @@ add_library(migraphx_gpu
     write_literals.cpp
     ${JIT_GPU_SRCS}
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 migraphx_generate_export_header(migraphx_gpu)
@@ -236,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
 set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
 if(MIGRAPHX_USE_FIND_2_API)
-    target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+    check_library_exists(MIOpen "miopenSetFindOptionPreallocatedTensor" "${MIOPEN_LOCATION}" HAS_PREALLOCATION_API)
+    if(HAS_PREALLOCATION_API)
+        target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API -DMIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS)
+    else()
+        target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
+    endif()
     message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
 else()
     message(STATUS "MIGraphx is using legacy Find API in MIOpen")
@@ -250,7 +265,11 @@ else()
 endif()
 target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
-target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)
+target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
+if(NOT WIN32)
+    # TODO: re-enable when CK is ported to Windows
+    target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
+endif()
 add_subdirectory(driver)
 add_subdirectory(hiprtc)
...
@@ -135,14 +135,13 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
     std::size_t max_global = ctx.get_current_device().get_cu_count() *
                              ctx.get_current_device().get_max_workitems_per_cu();
     return [n, over, max_global](std::size_t local) {
-        std::size_t num_elements = n;
+        // hip require global workitems multiple of local workitems. It may degrade performance.
+        // [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
+        // https://reviews.llvm.org/D155213
+        std::size_t num_elements = ((n + local - 1) / local) * local;
         std::size_t groups     = (num_elements + local - 1) / local;
         std::size_t max_blocks = max_global / local;
         std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
-#ifdef MIGRAPHX_USE_HIPRTC
-        if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
-            num_elements = ((num_elements + local - 1) / local) * local;
-#endif
         return std::min(nglobal, num_elements);
     };
 }
@@ -168,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
                    [](auto&& p) {
                        auto&& name = p.first;
                        auto&& c    = p.second;
-                       auto path = fs::path{"migraphx"} / "kernels" / name;
+                       auto path = name;
                        return src_file{path, c};
                    });
     srcs.push_back(src_file{fs::path{"main.cpp"},
...
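Note: the key change in compute_global_for is that n is now always rounded up to a multiple of local (previously this only happened under the hipRTC workarounds flag), since HIP requires the global workitem count to be a multiple of the workgroup size. The arithmetic in isolation:

#include <cassert>
#include <cstddef>

std::size_t round_up_to_multiple(std::size_t n, std::size_t local)
{
    // smallest multiple of local that is >= n
    return ((n + local - 1) / local) * local;
}

int main()
{
    assert(round_up_to_multiple(1000, 256) == 1024); // padded up to the next full workgroup
    assert(round_up_to_multiple(1024, 256) == 1024); // already a multiple: unchanged
    return 0;
}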
@@ -216,6 +216,7 @@ struct find_mlir_op
         "quant_dot",
         "add",
         "clip",
+        "relu",
         "sub",
         "mul",
         "div",
...
@@ -140,8 +140,11 @@ void gemm_impl(context& ctx,
         compute_type = rocblas_datatype_f32_r;
     }
-    rocblas_gemm_flags flag =
-        int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
+    rocblas_gemm_flags flag = rocblas_gemm_flags_none;
+#if ROCBLAS_VERSION_MAJOR < 3
+    if(int8_x4_format)
+        flag = rocblas_gemm_flags_pack_int8x4;
+#endif
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
...
@@ -32,6 +32,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/optional.hpp>
 #include <migraphx/rank.hpp>
+#include <migraphx/gpu/tuning_config.hpp>
 #include <functional>
 namespace migraphx {
@@ -68,12 +69,6 @@ struct compiler_replace
     }
 };
-struct tuning_config
-{
-    value problem;
-    std::vector<value> solutions;
-};
 using compiler_compile =
     std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
 using compiler_compile_op =
...
@@ -41,8 +41,6 @@ struct miopen_contiguous : unary_device<miopen_contiguous, &device::contiguous>
     shape compute_shape(const std::vector<shape>& inputs) const
     {
         check_shapes{inputs, *this}.has(2);
-        if(inputs.front().standard())
-            return inputs.front();
         auto lens = inputs.at(0).lens();
         auto t    = inputs.at(0).type();
         return {t, lens};
...
@@ -160,10 +160,31 @@ struct miopen_convolution
     shape find(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
     {
         shape workspace_shape{};
         auto x_desc = make_tensor(reshape_if_1d(inputs[0]), int8_x4_format);
         auto w_desc = make_tensor(reshape_if_1d(inputs[1]), int8_x4_format);
         auto y_desc = make_tensor(reshape_if_1d(output_shape));
+        auto* miopen_stream_handle = ctx.get_stream().get_miopen();
         std::size_t workspace_size = 0;
+        auto status = miopenConvolutionForwardGetWorkSpaceSize(miopen_stream_handle,
+                                                               w_desc.get(),
+                                                               x_desc.get(),
+                                                               cd.get(),
+                                                               y_desc.get(),
+                                                               &workspace_size);
+        if(status != miopenStatusSuccess)
+            MIGRAPHX_THROW("MIOpen" + op.name() + " : Failed to get forward workspace size");
+        workspace_shape = shape{shape::int8_type, {workspace_size}};
+        auto x_shape = inputs[0];
+        auto w_shape = inputs[1];
+        if(int8_x4_format)
+        {
+            x_shape = pack_int8_shape(x_shape);
+            w_shape = pack_int8_shape(w_shape);
+        }
 #ifdef MIGRAPHX_HAS_FIND_2_API
         {
             auto conv_problem = make_obj<miopen_problem>(
@@ -171,13 +192,34 @@ struct miopen_convolution
             set_tensor_descriptor(miopenTensorConvolutionX, x_desc, conv_problem);
             set_tensor_descriptor(miopenTensorConvolutionW, w_desc, conv_problem);
+            bool preallocate = false;
+#ifdef MIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS
+            // MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
+            preallocate = true;
+#endif
+            auto x = preallocate ? to_gpu(generate_argument(x_shape)) : inputs[0];
+            auto w = preallocate ? to_gpu(generate_argument(w_shape)) : inputs[1];
+            auto y = preallocate ? allocate_gpu(output_shape) : inputs[2];
+            auto workspace =
+                preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape);
             set_tensor_descriptor(miopenTensorConvolutionY, y_desc, conv_problem);
-            auto* miopen_stream_handle = ctx.get_stream().get_miopen();
-            solution_ptr = find_solution(
-                miopen_stream_handle, conv_problem.get(), ctx.get_exhaustive_tune_flag());
-            auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
+            const miopenTensorArgument_t tensor_args[3] = {
+                {miopenTensorConvolutionX, nullptr, x.implicit()},
+                {miopenTensorConvolutionW, nullptr, w.implicit()},
+                {miopenTensorConvolutionY, nullptr, y.implicit()},
+            };
+            solution_ptr = find_solution(miopen_stream_handle,
+                                         3,
+                                         tensor_args,
+                                         workspace.implicit(),
+                                         workspace_size,
+                                         conv_problem.get(),
+                                         ctx.get_exhaustive_tune_flag());
+            status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
             if(status != miopenStatusSuccess)
                 MIGRAPHX_THROW("MIOpen" + op.name() + " : failed to get solution's workspace size");
@@ -196,29 +238,10 @@ struct miopen_convolution
             return shape{shape::int8_type, {workspace_size}};
         }
 #else
-        auto status = miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
-                                                               w_desc.get(),
-                                                               x_desc.get(),
-                                                               cd.get(),
-                                                               y_desc.get(),
-                                                               &workspace_size);
-        if(status != miopenStatusSuccess)
-            MIGRAPHX_THROW("MIOpen" + op.name() + " : Failed to get forward workspace size");
-        workspace_shape = shape{shape::int8_type, {workspace_size}};
-        auto x_shape = inputs[0];
-        auto w_shape = inputs[1];
-        if(int8_x4_format)
-        {
-            x_shape = pack_int8_shape(x_shape);
-            w_shape = pack_int8_shape(w_shape);
-        }
         auto x         = to_gpu(generate_argument(x_shape));
         auto w         = to_gpu(generate_argument(w_shape));
         auto y         = allocate_gpu(output_shape);
         auto workspace = allocate_gpu(workspace_shape);
         int algo_count = 1;
         miopenConvAlgoPerf_t perf;
         status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
@@ -338,6 +361,7 @@ struct miopen_convolution
         return {s.type(), lens, strides};
     }
 };
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -75,21 +75,43 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr
 using miopen_problem  = MIGRAPHX_MANAGE_PTR(miopenProblem_t, miopenDestroyProblem);
 using miopen_solution = MIGRAPHX_MANAGE_PTR(miopenSolution_t, miopenDestroySolution);
-inline miopen_solution
-find_solution(miopenHandle_t handle, miopenProblem_t problem, bool tune = false)
+inline miopen_solution find_solution(miopenHandle_t handle,
+                                     size_t num_inputs,
+                                     const miopenTensorArgument_t* tensor_args,
+                                     void* workspace,
+                                     size_t workspace_size,
+                                     miopenProblem_t problem,
+                                     bool tune = false)
 {
     miopenSolution_t solution;
     size_t found = 0;
-    miopen_find_options fo = nullptr;
+    miopen_find_options fo = make_obj<miopen_find_options>(&miopenCreateFindOptions);
     if(tune)
     {
-        fo = make_obj<miopen_find_options>(&miopenCreateFindOptions);
         miopenSetFindOptionTuning(fo.get(), 1);
     }
-    auto status = miopenFindSolutions(handle, problem, fo.get(), &solution, &found, 1);
+#ifdef MIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS
+    for(auto i : range(num_inputs))
+    {
+        auto status = miopenSetFindOptionPreallocatedTensor(
+            fo.get(), tensor_args[i].id, tensor_args[i].buffer);
+        if(status != miopenStatusSuccess)
+            MIGRAPHX_THROW("MIOpen: failed to preallocate tensors for the find process");
+    }
+    auto status = miopenSetFindOptionPreallocatedWorkspace(fo.get(), workspace, workspace_size);
+    if(status != miopenStatusSuccess)
+        MIGRAPHX_THROW("MIOpen: failed to preallocate workspace for the find process");
+#else
+    miopenStatus_t status;
+    (void)(num_inputs);
+    (void)(tensor_args);
+    (void)(workspace_size);
+    (void)(workspace);
+#endif
+    status = miopenFindSolutions(handle, problem, fo.get(), &solution, &found, 1);
     auto result = miopen_solution{solution};
     if(status != miopenStatusSuccess or found == 0)
-        MIGRAPHX_THROW("MIOpen miopenFindSolutions failed");
+        MIGRAPHX_THROW("MIOpen: miopenFindSolutions failed");
     return result;
 }
...