Commit be38aff9 authored by turneram
Browse files

Merge remote-tracking branch 'origin/develop' into bert-attention-no-transpose-ops

parents 74b947ed 86061b4d
...@@ -203,6 +203,8 @@ rocm_enable_cppcheck( ...@@ -203,6 +203,8 @@ rocm_enable_cppcheck(
useSmartPointer:*make_shared_array.hpp useSmartPointer:*make_shared_array.hpp
constParameter:*src/targets/gpu/*.cpp constParameter:*src/targets/gpu/*.cpp
constParameter:*src/targets/gpu/*.hpp constParameter:*src/targets/gpu/*.hpp
# Suppress mlir_conv.cpp since this file will be deleted
*:*src/targets/gpu/mlir_conv.cpp
FORCE FORCE
INCONCLUSIVE INCONCLUSIVE
RULE_FILE RULE_FILE
......
...@@ -2,6 +2,6 @@ pfultz2/rocm-recipes ...@@ -2,6 +2,6 @@ pfultz2/rocm-recipes
facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
ccache@v4.1 ccache@v4.1
pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11
danmar/cppcheck@2.6 -DHAVE_RULES=1 danmar/cppcheck@2.8 -DHAVE_RULES=1
RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build
-f requirements.txt -f requirements.txt
...@@ -24,16 +24,16 @@ ...@@ -24,16 +24,16 @@
"import os.path\n", "import os.path\n",
"\n", "\n",
"if not os.path.exists(\"./utilities/coco.names\"):\n", "if not os.path.exists(\"./utilities/coco.names\"):\n",
" !wget https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/dependencies/coco.names -P ./utilities/\n", " !wget https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/dependencies/coco.names -P ./utilities/\n",
"if not os.path.exists(\"./utilities/yolov4_anchors.txt\"):\n", "if not os.path.exists(\"./utilities/yolov4_anchors.txt\"):\n",
" !wget https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/dependencies/yolov4_anchors.txt -P ./utilities/\n", " !wget https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/dependencies/yolov4_anchors.txt -P ./utilities/\n",
"if not os.path.exists(\"./utilities/input.jpg\"):\n", "if not os.path.exists(\"./utilities/input.jpg\"):\n",
" # The image used is from the COCO dataset (https://cocodataset.org/#explore)\n", " # The image used is from the COCO dataset (https://cocodataset.org/#explore)\n",
" # Other images can be tested by replacing the link below\n", " # Other images can be tested by replacing the link below\n",
" image_link = \"https://farm3.staticflickr.com/2009/2306189268_88cc86b30f_z.jpg\"\n", " image_link = \"https://farm3.staticflickr.com/2009/2306189268_88cc86b30f_z.jpg\"\n",
" !wget -O ./utilities/input.jpg $image_link\n", " !wget -O ./utilities/input.jpg $image_link\n",
"if not os.path.exists(\"./utilities/yolov4.onnx\"):\n", "if not os.path.exists(\"./utilities/yolov4.onnx\"):\n",
" !wget https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/yolov4/model/yolov4.onnx -P ./utilities/" " !wget https://github.com/onnx/models/raw/main/vision/object_detection_segmentation/yolov4/model/yolov4.onnx -P ./utilities/"
] ]
}, },
{ {
......
...@@ -39,10 +39,7 @@ template <class T, class F, class... Ts> ...@@ -39,10 +39,7 @@ template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
T* result = nullptr; T* result = nullptr;
// cppcheck-suppress redundantInitialization auto e = f(&result, std::forward<Ts>(xs)...);
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(&result, std::forward<Ts>(xs)...);
if(e != migraphx_status_success) if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function"); throw std::runtime_error("Failed to call function");
return result; return result;
...@@ -51,9 +48,6 @@ T* make(F f, Ts&&... xs) ...@@ -51,9 +48,6 @@ T* make(F f, Ts&&... xs)
template <class F, class... Ts> template <class F, class... Ts>
void call(F f, Ts&&... xs) void call(F f, Ts&&... xs)
{ {
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(std::forward<Ts>(xs)...); auto e = f(std::forward<Ts>(xs)...);
if(e != migraphx_status_success) if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function"); throw std::runtime_error("Failed to call function");
...@@ -340,7 +334,6 @@ struct interface_base : Base ...@@ -340,7 +334,6 @@ struct interface_base : Base
template <class T, class Setter, class F> template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f) void set_auto_fp(Setter setter, F f)
{ {
// cppcheck-suppress constParameter
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) { return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...); auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
}); });
......
...@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d) ...@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
// Collect all shapes // Collect all shapes
std::unordered_map<std::size_t, shape> shapes; std::unordered_map<std::size_t, shape> shapes;
{ {
// cppcheck-suppress variableScope
std::size_t i = 0; std::size_t i = 0;
fix([&](auto self, auto ss) { fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty()) if(ss.sub_shapes().empty())
...@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d) ...@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
} }
assert(offset == s.bytes()); assert(offset == s.bytes());
// cppcheck-suppress variableScope
std::size_t i = 0; std::size_t i = 0;
m_data = fix<data_t>([&](auto self, auto ss) { m_data = fix<data_t>([&](auto self, auto ss) {
data_t result; data_t result;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -71,6 +72,8 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -71,6 +72,8 @@ static bool try_compute_shape(instruction_ref ins,
void eliminate_contiguous::apply(module& m) const void eliminate_contiguous::apply(module& m) const
{ {
std::vector<instruction_ref> const_instruction;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
// return instruction should have inputs with standard shape // return instruction should have inputs with standard shape
...@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& m) const ...@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& m) const
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args; auto new_args = args;
auto mod_args = ins->module_inputs(); auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() == op_name)
...@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& m) const ...@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& m) const
} }
else if(prev->can_eval()) else if(prev->can_eval())
{ {
auto c = op::contiguous{}; const_instruction.push_back(arg);
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = m.add_literal(r.get_shape(), r.data());
m.replace_instruction(arg, l);
} }
} }
} }
} }
// Perform evaluations in parallel
std::vector<argument> literals(const_instruction.size());
par_for(const_instruction.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instruction[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
});
for(size_t i = 0; i < const_instruction.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instruction[i], l);
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1) ...@@ -207,8 +207,7 @@ auto visit_all_pack(const shape& s, V1&& v1)
template <class T, class... Ts> template <class T, class... Ts>
auto visit_all(T&& x, Ts&&... xs) auto visit_all(T&& x, Ts&&... xs)
{ {
auto&& s = x.get_shape(); auto&& s = x.get_shape();
// cppcheck-suppress redundantInitialization
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...}; std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
......
...@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{}) ...@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
value result = value::array{}; value result = value::array{};
for(auto&& y : x) for(auto&& y : x)
{ {
auto e = to_value(y);
result.insert(to_value(y)); result.insert(to_value(y));
} }
return result; return result;
......
...@@ -120,10 +120,8 @@ struct tensor_view ...@@ -120,10 +120,8 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)]; return m_data[m_shape.index(this->size() - 1)];
} }
// cppcheck-suppress functionConst
iterator begin() { return {0, {this}}; } iterator begin() { return {0, {this}}; }
// cppcheck-suppress functionConst
iterator end() { return {this->size(), {this}}; } iterator end() { return {this->size(), {this}}; }
const_iterator begin() const { return {0, {this}}; } const_iterator begin() const { return {0, {this}}; }
......
...@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out ...@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
// cppcheck-suppress uninitvar
if(out_error != nullptr) if(out_error != nullptr)
*out_error = error; *out_error = error;
return error <= threshold; return error <= threshold;
......
...@@ -729,7 +729,6 @@ std::unordered_map<instruction_ref, std::string> ...@@ -729,7 +729,6 @@ std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
{ {
os << "migraphx::module p;" << std::endl; os << "migraphx::module p;" << std::endl;
// cppcheck-suppress variableScope
unsigned long seed = 0; unsigned long seed = 0;
names = this->print( names = this->print(
[&](auto ins, auto ins_names) { [&](auto ins, auto ins_names) {
......
...@@ -128,7 +128,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -128,7 +128,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::fill_n(values["stride"].begin(), kdims, 1); std::fill_n(values["stride"].begin(), kdims, 1);
} }
// used to calculate the supposed output shape // used to calculate the supposed output shape
std::vector<int64_t> orig_padding(paddings.begin(), paddings.end()); std::vector<int64_t> orig_padding = paddings;
std::vector<int64_t> slice_start; std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end; std::vector<int64_t> slice_end;
......
...@@ -30,11 +30,11 @@ struct parse_squeeze : op_parser<parse_squeeze> ...@@ -30,11 +30,11 @@ struct parse_squeeze : op_parser<parse_squeeze>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto op = parser.load(opd.op_name, info); auto op = parser.load(opd.op_name, info);
std::vector<int64_t> axes;
if(args.size() == 2) if(args.size() == 2)
{ {
auto arg_axes = args.at(1)->eval(); auto arg_axes = args.at(1)->eval();
check_arg_empty(arg_axes, "PARSE_" + opd.op_name + ": cannot handle variable axes!"); check_arg_empty(arg_axes, "PARSE_" + opd.op_name + ": cannot handle variable axes!");
std::vector<int64_t> axes;
arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); }); arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); });
op = assign_axes(op, axes); op = assign_axes(op, axes);
} }
......
...@@ -20,7 +20,6 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out ...@@ -20,7 +20,6 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
int ec = 0; int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{})) if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl; std::cout << cmd << std::endl;
std::array<char, 128> buffer;
auto closer = [&](FILE* stream) { auto closer = [&](FILE* stream) {
auto status = pclose(stream); auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT
...@@ -30,6 +29,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out ...@@ -30,6 +29,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
if(!pipe) if(!pipe)
MIGRAPHX_THROW("popen() failed: " + cmd); MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data()); std_out(buffer.data());
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraphx { namespace migraphx {
...@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins) ...@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
std::unordered_set<instruction_ref> const_instrs;
auto last = std::prev(m.end());
// Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m)) for(auto i : iterator_for(m))
{ {
if(i->name() != "@literal") if(is_const(i) and i != last)
continue; continue;
if(i->outputs().empty())
continue; std::copy_if(
fix([&](auto self, auto ins) { i->inputs().begin(),
std::unordered_set<instruction_ref> children(ins->outputs().begin(), i->inputs().end(),
ins->outputs().end()); std::inserter(const_instrs, const_instrs.begin()),
for(auto child : children) [&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
{ }
if(child->name() == "@literal" or skip_propogate(child))
{ // Compute literals in parallel
self(child); std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
continue; std::vector<argument> literals(const_instrs_vec.size());
} par_for(const_instrs_vec.size(), 1, [&](const auto i) {
auto r = child->eval(); literals[i] = const_instrs_vec[i]->eval();
if(not r.empty()) });
{
assert(r.get_shape() == child->get_shape()); // Replace instructions in m
auto l = m.add_literal(r.get_shape(), r.data()); for(size_t i = 0; i < const_instrs_vec.size(); i++)
self(m.replace_instruction(child, l)); {
} if(not literals[i].empty())
} {
})(i); assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
}
} }
} }
......
...@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy> ...@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy>
return inputs.at(1); return inputs.at(1);
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
argument result = get_arg(args, args.size() - 1); argument result = get_arg(args, args.size() - 1);
......
...@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather> ...@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather>
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
......
...@@ -323,7 +323,6 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>> ...@@ -323,7 +323,6 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
argument result = get_arg(args, args.size() - 1); argument result = get_arg(args, args.size() - 1);
...@@ -362,7 +361,6 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>> ...@@ -362,7 +361,6 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
argument result = get_arg(args, args.size() - 1); argument result = get_arg(args, args.size() - 1);
......
...@@ -134,7 +134,6 @@ struct hiprtc_program ...@@ -134,7 +134,6 @@ struct hiprtc_program
std::vector<char> buffer(n); std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() == 0); assert(buffer.back() == 0);
// cppcheck-suppress returnDanglingLifetime
return {buffer.begin(), buffer.end() - 1}; return {buffer.begin(), buffer.end() - 1};
} }
......
...@@ -681,7 +681,7 @@ struct miopen_fusion ...@@ -681,7 +681,7 @@ struct miopen_fusion
struct miopen_conv_bias struct miopen_conv_bias
{ {
op::convolution op; op::convolution op;
fusion f = {}; fusion fp = {};
fusion::op_t conv = {}; fusion::op_t conv = {};
fusion::op_t bias = {}; fusion::op_t bias = {};
...@@ -705,19 +705,19 @@ struct miopen_conv_bias ...@@ -705,19 +705,19 @@ struct miopen_conv_bias
float beta = 0; float beta = 0;
miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit()); miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit()); miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
return f.execute(ctx, fargs, args[0], args[4]); return fp.execute(ctx, fargs, args[0], args[4]);
} }
void finalize(context& ctx, const shape&, const std::vector<shape>& inputs) void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
{ {
f = fusion(inputs[0]); fp = fusion(inputs[0]);
conv = f.create_conv(op, inputs[1]); conv = fp.create_conv(op, inputs[1]);
bias = f.create_bias(inputs[3]); bias = fp.create_bias(inputs[3]);
if(not f.compile(ctx)) if(not fp.compile(ctx))
MIGRAPHX_THROW("Failed to compile fusion plan"); MIGRAPHX_THROW("Failed to compile fusion plan");
} }
shape get_workspace(context& ctx) { return f.get_workspace(ctx); } shape get_workspace(context& ctx) { return fp.get_workspace(ctx); }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
...@@ -728,7 +728,7 @@ MIGRAPHX_REGISTER_OP(miopen_conv_bias) ...@@ -728,7 +728,7 @@ MIGRAPHX_REGISTER_OP(miopen_conv_bias)
struct miopen_conv_bias_relu struct miopen_conv_bias_relu
{ {
op::convolution op; op::convolution op;
fusion f = {}; fusion fp = {};
fusion::op_t conv = {}; fusion::op_t conv = {};
fusion::op_t bias = {}; fusion::op_t bias = {};
fusion::op_t relu = {}; fusion::op_t relu = {};
...@@ -754,18 +754,18 @@ struct miopen_conv_bias_relu ...@@ -754,18 +754,18 @@ struct miopen_conv_bias_relu
miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit()); miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit()); miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0); miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
return f.execute(ctx, fargs, args[0], args[4]); return fp.execute(ctx, fargs, args[0], args[4]);
} }
void finalize(context& ctx, const shape&, const std::vector<shape>& inputs) void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
{ {
f = fusion(inputs[0]); fp = fusion(inputs[0]);
conv = f.create_conv(op, inputs[1]); conv = fp.create_conv(op, inputs[1]);
bias = f.create_bias(inputs[3]); bias = fp.create_bias(inputs[3]);
relu = f.create_relu(); relu = fp.create_relu();
f.compile(ctx); fp.compile(ctx);
} }
shape get_workspace(context& ctx) { return f.get_workspace(ctx); } shape get_workspace(context& ctx) { return fp.get_workspace(ctx); }
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
...@@ -875,7 +875,6 @@ struct find_conv_pointwise ...@@ -875,7 +875,6 @@ struct find_conv_pointwise
{ {
if(i.name()[0] == '@') if(i.name()[0] == '@')
continue; continue;
auto inputs = to_shapes(i.inputs());
op.ops.push_back({{i.get_operator()}}); op.ops.push_back({{i.get_operator()}});
} }
std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins}; std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment