"configs/datasets/vscode:/vscode.git/clone" did not exist on "d2c40e564855c3e2a24bfdef04f8b3c9304c5e7b"
Commit 309df0d2 authored by wsttiger's avatar wsttiger
Browse files

Merge branch 'master' into concat

parents 3de56715 76f7ae49
@@ -93,8 +93,11 @@ rocm_enable_cppcheck(
     noExplicitConstructor
     passedByValue
     unusedStructMember
+    functionStatic
+    functionConst
     definePrefix:*test/include/test.hpp
     FORCE
+    INCONCLUSIVE
     RULE_FILE
     ${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules
     SOURCES
...
@@ -65,6 +65,9 @@ ADD dev-requirements.txt /dev-requirements.txt
 ADD requirements.txt /requirements.txt
 RUN cget -p $PREFIX install -f /dev-requirements.txt -DMIOPEN_CACHE_DIR=""
+ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
+ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
 ENV LD_LIBRARY_PATH=$PREFIX/lib
 # Install doc requirements
...
@@ -3,7 +3,7 @@ pcre
 danmar/cppcheck@f965e5873 -DHAVE_RULES=1
 ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c
 # Needed for clang-ocl
-RadeonOpenCompute/rocm-cmake@d82a77c --build
+RadeonOpenCompute/rocm-cmake@6240bb3 --build
 RadeonOpenCompute/clang-ocl@799713643b5591a3b877c586ef2c7fbc012af819
 # python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
 -f requirements.txt
 add_library(migraph
     auto_contiguous.cpp
+    constant_propagate.cpp
     dead_code_elimination.cpp
     eliminate_allocation.cpp
     eliminate_contiguous.cpp
@@ -10,6 +11,7 @@ add_library(migraph
     instruction.cpp
     program.cpp
     shape.cpp
+    simplify_algebra.cpp
     simplify_reshapes.cpp
     opt/memory_coloring.cpp
     opt/memory_coloring_impl.cpp
...
#include <migraph/constant_propagate.hpp>
#include <migraph/program.hpp>
#include <migraph/matcher.hpp>
#include <migraph/literal.hpp>
namespace migraph {
struct match_const_add
{
auto matcher() const
{
return match::name("add")(match::args(match::name("@literal"), match::name("@literal")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto arg1 = ins->inputs().at(0)->get_literal();
auto arg2 = ins->inputs().at(1)->get_literal();
auto sum = p.add_literal(transform(arg1, arg2, [](auto x, auto y) { return x + y; }));
p.replace_instruction(ins, sum);
}
};
void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); }
} // namespace migraph
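
A minimal driver sketch for the new pass (hypothetical usage, not part of this commit; it assumes `program::add_instruction` and an `int` overload of `add_literal`, and uses `op::add` from `<migraph/operators.hpp>`):

```cpp
#include <migraph/constant_propagate.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>

int main()
{
    migraph::program p;
    auto one = p.add_literal(1); // assumed int overload
    auto two = p.add_literal(2);
    p.add_instruction(migraph::op::add{}, one, two); // assumed API
    // Rewrites add(@literal, @literal) into the single literal 3
    migraph::constant_propagate{}.apply(p);
}
```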
@@ -8,13 +8,19 @@ namespace migraph {
 struct check_shapes
 {
-    const std::vector<shape>* shapes;
+    const shape* begin;
+    const shape* end;
     const std::string name;

-    check_shapes(const std::vector<shape>& s) : shapes(&s) {}
+    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
+    {
+    }
+    check_shapes(const std::vector<shape>& s) : begin(s.data()), end(s.data() + s.size()) {}

     template <class Op>
-    check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
+    check_shapes(const std::vector<shape>& s, const Op& op)
+        : begin(s.data()), end(s.data() + s.size()), name(op.name())
     {
     }
@@ -26,21 +32,30 @@ struct check_shapes
         return name + ": ";
     }

+    std::size_t size() const
+    {
+        if(begin == end)
+            return 0;
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        return end - begin;
+    }
+
     const check_shapes& has(std::size_t n) const
     {
-        assert(shapes != nullptr);
-        if(shapes->size() != n)
+        if(size() != n)
             MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
-                          " but given " + std::to_string(shapes->size()));
+                          " but given " + std::to_string(size()));
         return *this;
     }

     const check_shapes& only_dims(std::size_t n) const
     {
-        assert(shapes != nullptr);
-        if(!shapes->empty())
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        if(begin != end)
         {
-            if(shapes->front().lens().size() != n)
+            if(begin->lens().size() != n)
                 MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
         }
         return *this;
@@ -105,19 +120,38 @@ struct check_shapes
     template <class F>
     bool same(F f) const
     {
-        assert(shapes != nullptr);
-        if(shapes->empty())
+        if(begin == end)
             return true;
-        auto&& key = f(shapes->front());
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        auto&& key = f(*begin);
         return this->all_of([&](const shape& s) { return f(s) == key; });
     }

     template <class Predicate>
     bool all_of(Predicate p) const
     {
-        assert(shapes != nullptr);
-        return std::all_of(shapes->begin(), shapes->end(), p);
+        if(begin == end)
+            return true;
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        return std::all_of(begin, end, p);
+    }
+
+    const shape* get(long i)
+    {
+        // Negative indices count back from the end, e.g. get(-1) is the last shape
+        if(i < 0)
+            i += static_cast<long>(size());
+        if(i < 0 or static_cast<std::size_t>(i) >= size())
+            MIGRAPH_THROW(prefix() + "Accessing shape out of bounds");
+        assert(begin != nullptr);
+        assert(end != nullptr);
+        return begin + i;
+    }
+
+    check_shapes slice(long start) { return {get(start), end, name}; }
+    check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
 };

 } // namespace migraph
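
A quick sketch of how the new range-based interface composes (hypothetical call site: `inputs` is the `std::vector<shape>` handed to an operator's `compute_shape`, and `shape::type()` is the projection):

```cpp
// Require exactly five shapes (as gpu::conv_bias does below), then check
// that the first two arguments (input and weights) share an element type.
check_shapes{inputs, *this}.has(5);
check_shapes{inputs, *this}.slice(0, 2).same([](const shape& s) { return s.type(); });
```

Because `slice` returns a new `check_shapes` over a sub-range of the same pointers, checks can be restricted to any argument window without copying shapes.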
...
#ifndef MIGRAPH_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#define MIGRAPH_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#include <string>
namespace migraph {
struct program;
struct constant_propagate
{
std::string name() const { return "constant_propagate"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
@@ -108,6 +108,20 @@ literal transform(literal l, F f)
     return result;
 }

+template <class F>
+literal transform(literal l1, literal l2, F f)
+{
+    assert(l1.get_shape() == l2.get_shape());
+    literal result;
+    visit_all(l1, l2)([&](auto x, auto y) {
+        using type = std::remove_cv_t<typename decltype(x)::value_type>;
+        std::vector<type> output(x.size(), 0.0);
+        std::transform(x.begin(), x.end(), y.begin(), output.begin(), f);
+        result = literal{l1.get_shape(), output};
+    });
+    return result;
+}
+
 } // namespace migraph
 #endif
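
The new two-argument overload is what makes the constant folding in `constant_propagate` a one-liner; standalone it reads as (hypothetical literals `a` and `b` of identical shape):

```cpp
// Elementwise sum, visiting both literals under their common value type
auto sum = migraph::transform(a, b, [](auto x, auto y) { return x + y; });
```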
@@ -76,7 +76,7 @@ struct bindable_matcher
 {
     M m;

-    auto bind(std::string name) { return bind_match(m, std::move(name)); }
+    auto bind(std::string name) const { return bind_match(m, std::move(name)); }

     instruction_ref match(matcher_context& ctx, instruction_ref ins) const
     {
@@ -137,7 +137,7 @@ struct basic_matcher
         });
     }

-    auto bind(std::string name) { return bind_match(m, name); }
+    auto bind(std::string name) const { return bind_match(m, std::move(name)); }

     instruction_ref match(matcher_context& ctx, instruction_ref ins) const
     {
@@ -176,12 +176,13 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
     inline instruction_ref name##_m::match(__VA_ARGS__) const

 /// This macro takes care of the boilerplate for defining a predicate matcher
 #define MIGRAPH_PRED_MATCHER(name, ...)                                                  \
     struct name##_m                                                                      \
     {                                                                                    \
         bool operator()(__VA_ARGS__) const;                                              \
     };                                                                                   \
-    const constexpr auto name = migraph::match::basic_matcher<predicate_matcher<name##_m>>{{}}; \
+    const constexpr auto name =                                                          \
+        migraph::match::basic_matcher<migraph::match::predicate_matcher<name##_m>>{{}};  \
     inline bool name##_m::operator()(__VA_ARGS__) const

 struct matcher_result
@@ -263,7 +264,29 @@ auto any_of(Ts... ms)
     });
 }

+MIGRAPH_PRED_MATCHER(any, instruction_ref) { return true; }
+MIGRAPH_PRED_MATCHER(none, instruction_ref) { return false; }
 MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
+MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
+{
+    return ins->get_shape().broadcasted();
+}
+
+MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
+{
+    if(ins->outputs().size() == 1)
+        return ins->outputs().front();
+    return ctx.not_found();
+}
+
+MIGRAPH_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
+{
+    if(ins->outputs().size() == 1)
+        return ins;
+    if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
+        return ins;
+    return ctx.not_found();
+}

 inline auto name(std::string name)
 {
@@ -306,6 +329,14 @@ auto args(Ms... ms)
         });
 }

+inline auto either_arg(std::size_t i, std::size_t j)
+{
+    return [=](auto m1, auto m2) {
+        return match::any_of(match::all_of(arg(i)(m1), arg(j)(m2)),
+                             match::all_of(arg(j)(m1), arg(i)(m2)));
+    };
+}
+
 } // namespace match
 } // namespace migraph
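
To make the new combinators concrete, a hypothetical pattern in the style of the fusion pass later in this diff:

```cpp
// Matches add(x, y) where one argument, in either position, is a
// broadcasted value read only by this add.
auto m = match::name("add")(match::either_arg(0, 1)(
    match::broadcast_shape(match::used_once()).bind("bias"),
    match::any().bind("other")));
```

`either_arg` tries both `(i, j)` assignments via `any_of`/`all_of`, so commutative operators don't need two hand-written patterns, and `used_once` accepts an instruction only when its result has a single reader (or it is the final instruction of the program).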
...
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string>
namespace migraph {
struct program;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
@@ -199,12 +199,10 @@ void memory_coloring_impl::register_operand_alias()

 void memory_coloring_impl::rewrite()
 {
-    instruction_ref end = p_program->end();
-    instruction_ref scratch_param = end;
     std::vector<std::size_t> dims;
     dims.push_back(required_bytes / sizeof(float));
     shape s = {shape::float_type, dims};
-    scratch_param = p_program->add_parameter("scratch", s);
+    instruction_ref scratch_param = p_program->add_parameter("scratch", s);
     for(auto ins : iterator_for(*p_program))
     {
         const instruction* p_iter = &(*ins);
...
#include <migraph/simplify_algebra.hpp>
#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/matcher.hpp>
#include <migraph/literal.hpp>
namespace migraph {
struct find_add_lit_broadcast
{
auto lit_broadcast() const
{
return match::any_of(match::name("@literal"), match::name("broadcast"));
}
auto not_lit_broadcast() const
{
return match::none_of(match::name("@literal"), match::name("broadcast"));
}
auto add_lit_broadcast(std::string x, std::string y) const
{
return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
not_lit_broadcast().bind(std::move(y))));
}
auto matcher() const
{
return match::name("add")(
match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
if(a_ins->name() != b_ins->name())
return;
instruction_ref sumab;
if(a_ins->name() == "broadcast")
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum =
p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum);
}
else
{
sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
}
auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
p.replace_instruction(ins, op::add{}, sumxy, sumab);
}
};
void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
} // namespace migraph
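
The pass reassociates nested adds so the broadcast/literal operands are combined once and the variable operands separately; sketched as a before/after (in the notation of the matchers above):

```cpp
// Before: add(add(broadcast(a), x), add(broadcast(b), y))
// After:  add(add(x, y), broadcast(add(a, b)))
//
// add(a, b) now runs on the small pre-broadcast shape; when a and b are
// literals rather than broadcasts, the rewrite produces add(a, b) directly,
// which constant_propagate can then fold.
```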
@@ -10,6 +10,11 @@ void add(const argument& result, const argument& arg1, const argument& arg2)
     nary(result, arg1, arg2)([](auto x, auto y) { return x + y; });
 }

+void add(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
+{
+    nary(result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; });
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace migraph
@@ -10,6 +10,15 @@ void add_relu(const argument& result, const argument& arg1, const argument& arg2
     nary(result, arg1, arg2)([](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
 }

+void add_relu(const argument& result,
+              const argument& arg1,
+              const argument& arg2,
+              const argument& arg3)
+{
+    nary(result, arg1, arg2, arg3)(
+        [](auto x, auto y, auto z) { return std::max<decltype(x + y + z)>(0, x + y + z); });
+}
+
 } // namespace device
 } // namespace gpu
 } // namespace migraph
@@ -51,6 +51,108 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
     });
 }
template <class F>
void trinary_broadcast_vec_impl(
F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data());
auto* zp = as_vec4(input3.data());
auto* outp = as_vec4(output.data());
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = zp[i];
}
__syncthreads();
auto* bp = as_pointer(buffer);
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> y = yp[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], y[j], b);
}
outp[i] = out;
}
});
});
}
template <class F>
void trinary_broadcast_impl(
F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg3.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data();
auto* yp = input2.data();
auto* zp = input3.data();
auto* outp = output.data();
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size();
launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = zp[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
type x = xp[i];
type y = yp[i];
outp[i] = f(x, y, b);
}
});
});
}
 template <class F>
 void binary_broadcast_vec_impl(F f,
                                const argument& result,
@@ -247,6 +349,36 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
     };
 }
inline auto
nary(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
{
return [=](auto f) {
// TODO: Check result and arg1 shape is the same
if(arg1.get_shape().standard() and arg2.get_shape().standard() and
arg3.get_shape().broadcasted())
{
auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg3.get_shape().strides();
auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
auto b_idx = std::distance(strides.begin(), b_it);
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(arg3.get_shape().lens()[b_idx] == b_len);
if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
{
const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
(arg1.get_shape().elements() % 4 == 0);
if(divisible_by_4)
trinary_broadcast_vec_impl(f, result, arg1, arg2, arg3);
else
trinary_broadcast_impl(f, result, arg1, arg2, arg3);
return;
}
}
nary_impl(f, result, arg1, arg2, arg3);
};
}
 } // namespace device
 } // namespace gpu
 } // namespace migraph
...
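
The broadcast kernels above key off one stride trick worth spelling out: for a bias broadcast along a single axis, `arg3`'s shape has exactly one non-zero stride, so the kernel stages that axis in LDS and recovers the bias index of a flat output offset purely from the output strides. A host-side sketch of the index math (hypothetical NCHW numbers, not from this diff):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Output NCHW = {2, 3, 4, 5}, bias broadcast along C (axis 1):
    // output strides = {60, 20, 5, 1}, so bdim = 1
    const std::size_t bdim_len         = 3;                       // lens[bdim]
    const std::size_t bdim_stride      = 20;                      // strides[bdim]
    const std::size_t bdim_next_stride = bdim_stride * bdim_len;  // 60

    // Flat offset 47 = n*60 + c*20 + h*5 + w -> n=0, c=2, h=1, w=2
    const std::size_t i    = 47;
    const std::size_t bidx = (i % bdim_next_stride) / bdim_stride;
    std::printf("bias index: %zu\n", bidx); // prints 2, the channel
}
```

The vectorized variant applies the same formula to `i * vec_size`, which is why `nary` only selects it when both the bias length and stride are divisible by 4.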
 #include <migraph/gpu/fuse_ops.hpp>
-#include <migraph/iterator_for.hpp>
+#include <migraph/matcher.hpp>
+#include <migraph/gpu/miopen.hpp>
+#include <migraph/gpu/convolution.hpp>
 #include <migraph/gpu/device/add_relu.hpp>
 #include <migraph/instruction.hpp>
@@ -7,6 +9,171 @@ namespace migraph {
 namespace gpu {
struct fusion
{
using op_t = miopenFusionOpDescriptor_t;
shared<fusion_plan_descriptor> fp;
// Used as a temporary hack to keep descriptor references alive
std::vector<std::shared_ptr<void>> storage;
template <class T>
auto keep_alive(T x)
{
auto result = share(std::move(x));
storage.push_back(result);
return result;
}
fusion(const shape& input)
// : fp(make_fusion_plan(input))
{
auto t = make_tensor(input);
fp = make_fusion_plan(t);
keep_alive(std::move(t));
}
op_t operator[](std::size_t i) const
{
op_t result;
auto status = miopenFusionPlanGetOp(fp.get(), i, &result);
if(status != miopenStatusSuccess)
MIGRAPH_THROW("Failed retrieving operator at " + std::to_string(i));
return result;
}
auto get() const { return fp.get(); }
op_t create_bias(const shape& bias)
{
op_t result;
auto b = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
auto t = keep_alive(make_tensor(b));
auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
if(status != miopenStatusSuccess)
MIGRAPH_THROW("Creating operator failed");
return result;
}
op_t create_relu()
{
op_t result;
auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU);
if(status != miopenStatusSuccess)
MIGRAPH_THROW("Creating operator failed");
return result;
}
op_t create_conv(const op::convolution& op, const shape& weights)
{
op_t result;
auto cd = keep_alive(make_conv(op));
auto t = keep_alive(make_tensor(weights));
auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
if(status != miopenStatusSuccess)
MIGRAPH_THROW("Creating operator failed");
return result;
}
shape get_workspace(context&)
{
// TODO: Use zero workspace for now
std::size_t ws_size = 0;
// int algo_count = 1;
// miopenConvFwdAlgorithm_t algo;
// miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
// miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), fp.get(), &ws_size, algo);
return shape{shape::int8_type, {ws_size}};
}
void compile(context& ctx)
{
auto status = miopenCompileFusionPlan(ctx.handle.get(), fp.get());
if(status != miopenStatusSuccess)
MIGRAPH_THROW("Compiling fusion plan failed");
}
argument execute(context& ctx,
const fused_operator_args& fargs,
const argument& x,
const argument& y) const
{
auto x_td = make_tensor(x.get_shape());
auto y_td = make_tensor(y.get_shape());
auto status = miopenExecuteFusionPlan(ctx.handle.get(),
fp.get(),
x_td.get(),
x.implicit(),
y_td.get(),
y.implicit(),
fargs.get());
if(status != miopenStatusSuccess)
MIGRAPH_THROW("Failed to execute fusion plan");
return y;
}
};
MIGRAPH_PRED_MATCHER(bias_shape, instruction_ref ins)
{
auto&& s = ins->get_shape();
return s.broadcasted() and s.strides().size() == 4 and s.strides()[0] == 0 and
s.strides()[1] != 0 and s.strides()[2] == 0 and s.strides()[3] == 0;
}
// TODO: Move to another header
template <class T, class... Ts>
std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {std::move(x), std::move(static_cast<T>(xs))...};
}
MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
if(ins->name() != "gpu::convolution")
return false;
auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4);
auto channels = wei.lens()[1] * wei.lens()[0];
if(wei.lens()[0] > 64 and channels > 32768)
return false;
auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.algo == miopenConvolutionFwdAlgoWinograd)
return false;
auto op = conv.op;
return op.padding == make_array<size_t>(0, 0) and op.stride == make_array<size_t>(1, 1) and
op.dilation == make_array<size_t>(1, 1);
}
struct hip_triadd
{
std::string name() const { return "hip::triadd"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
device::add(args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
};
struct hip_triadd_relu
{
std::string name() const { return "hip::triadd_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
device::add_relu(args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
};
 struct hip_add_relu
 {
     std::string name() const { return "hip::add_relu"; }
@@ -22,20 +189,198 @@ struct hip_add_relu
     }
 };

-void fuse_ops::apply(program& p) const
+struct find_add_relu
 {
-    for(auto ins : iterator_for(p))
+    auto matcher() const
+    {
+        return match::name("gpu::relu")(match::arg(0)(
+            match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
     {
-        if(ins->name() != "gpu::relu")
-            continue;
-        auto add_ins = ins->inputs().front();
-        if(add_ins->name() != "gpu::add")
-            continue;
-        auto args = add_ins->inputs();
+        auto add_ins = r.instructions["add"];
+        auto ins = r.result;
+        auto args = add_ins->inputs();
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
-        p.replace_instruction(ins, hip_add_relu{}, args);
+        if(add_ins->name() == "gpu::add")
+            p.replace_instruction(ins, hip_add_relu{}, args);
+        else if(add_ins->name() == "hip::triadd")
+            p.replace_instruction(ins, hip_triadd_relu{}, args);
     }
+};
struct find_triadd
{
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
match::any().bind("input")));
}
void apply(program& p, match::matcher_result r) const
{
auto add_ins = r.instructions["add"];
auto input_ins = r.instructions["input"];
auto ins = r.result;
auto args = add_ins->inputs();
auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
return;
args.insert(args.begin(), input_ins);
// Ensure the last arguments is the broadcasted one
auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_triadd{}, args);
}
};
struct miopen_conv_bias
{
op::convolution op;
fusion f;
fusion::op_t conv;
fusion::op_t bias;
miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
: op(c), f(input)
{
conv = f.create_conv(op, weights);
bias = f.create_bias(b);
}
std::string name() const { return "gpu::conv_bias"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(5);
// TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto fargs = make_fused_args();
float alpha = 1, beta = 0;
miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
return f.execute(ctx, fargs, args[0], args[4]);
}
shape compile(context& ctx)
{
f.compile(ctx);
return f.get_workspace(ctx);
}
};
struct miopen_conv_bias_relu
{
op::convolution op;
fusion f;
fusion::op_t conv;
fusion::op_t bias;
fusion::op_t relu;
miopen_conv_bias_relu(op::convolution c,
const shape& input,
const shape& weights,
const shape& b)
: op(c), f(input)
{
conv = f.create_conv(op, weights);
bias = f.create_bias(b);
relu = f.create_relu();
}
std::string name() const { return "gpu::conv_bias_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(5);
// TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto fargs = make_fused_args();
float alpha = 1, beta = 0;
miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
return f.execute(ctx, fargs, args[0], args[4]);
}
shape compile(context& ctx)
{
f.compile(ctx);
return f.get_workspace(ctx);
}
};
template <class... Ms>
auto conv_bias(Ms... ms)
{
return match::name("gpu::add")(
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")),
ms...);
}
template <class Op>
void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
{
auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"];
auto ins = r.result;
auto input_ins = conv_ins->inputs().at(0);
auto weights_ins = conv_ins->inputs().at(1);
auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
auto alloc_ins = ins->inputs().back();
auto old_ws_ins = conv_ins->inputs().at(2);
Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
// TODO: Insert ws allocation
auto ws = cb.compile(ctx);
p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
}
struct find_conv_bias
{
context* ctx = nullptr;
auto matcher() const
{
return conv_bias(match::none_of(match::output(match::name("gpu::relu"))));
}
void apply(program& p, match::matcher_result r) const
{
apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
}
};
struct find_conv_bias_relu
{
context* ctx = nullptr;
auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
void apply(program& p, match::matcher_result r) const
{
apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
}
};
void fuse_ops::apply(program& p) const
{
// clang-format off
match::find_matches(p, find_triadd{});
match::find_matches(p,
find_conv_bias_relu{ctx},
find_conv_bias{ctx},
find_add_relu{}
);
// clang-format on
 }
 } // namespace gpu
...
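
Ordering matters in the new `fuse_ops::apply`: `find_triadd` runs in its own `find_matches` call first, so a chain of two `gpu::add`s is collapsed into `hip::triadd` before `find_add_relu` looks for either spelling. A sketch of driving the pass (hypothetical; assumes a `gpu::context` named `ctx` produced by lowering):

```cpp
// The pass keeps a non-owning pointer to the context so the MIOpen
// fusion plans (conv+bias, conv+bias+relu) can be compiled eagerly.
migraph::gpu::fuse_ops fuse{&ctx};
fuse.apply(p); // rewrites triadd, conv+bias(+relu), and add+relu patterns
```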
@@ -10,6 +10,8 @@ namespace device {
 void add(const argument& result, const argument& arg1, const argument& arg2);

+void add(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3);
+
 } // namespace device
 } // namespace gpu
 } // namespace migraph
...
@@ -10,6 +10,11 @@ namespace device {
 void add_relu(const argument& result, const argument& arg1, const argument& arg2);

+void add_relu(const argument& result,
+              const argument& arg1,
+              const argument& arg2,
+              const argument& arg3);
+
 } // namespace device
 } // namespace gpu
 } // namespace migraph
...
@@ -10,6 +10,7 @@ namespace gpu {
 struct fuse_ops
 {
+    context* ctx = nullptr;
     std::string name() const { return "gpu::fuse_ops"; }
     void apply(program& p) const;
 };
...