Commit 99ee76c0 authored by Paul
Browse files

Merge branch 'master' into mem_color_separate_literal-master

parents 85c2c29d f9f4f713
CheckOptions: CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: modernize-loop-convert.MinConfidence - key: modernize-loop-convert.MinConfidence
value: risky value: risky
- key: modernize-loop-convert.NamingStyle - key: modernize-loop-convert.NamingStyle
......
...@@ -22,16 +22,6 @@ add_compile_options(-std=c++14) ...@@ -22,16 +22,6 @@ add_compile_options(-std=c++14)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings) include(EnableCompilerWarnings)
# Override clang-tidy to not find the version from hcc
# find_program(CLANG_TIDY_EXE
# NAMES
# clang-tidy
# clang-tidy-5.0
# clang-tidy-6.0
# clang-tidy-7.0
# PATHS
# /usr/local/opt/llvm/bin
# )
include(ROCMClangTidy) include(ROCMClangTidy)
rocm_enable_clang_tidy( rocm_enable_clang_tidy(
CHECKS CHECKS
...@@ -87,8 +77,11 @@ rocm_enable_clang_tidy( ...@@ -87,8 +77,11 @@ rocm_enable_clang_tidy(
) )
include(ROCMCppCheck) include(ROCMCppCheck)
rocm_enable_cppcheck( rocm_enable_cppcheck(
CHECKS CHECKS
all warning
style
performance
portability
SUPPRESS SUPPRESS
ConfigurationNotChecked ConfigurationNotChecked
unmatchedSuppression unmatchedSuppression
...@@ -96,7 +89,10 @@ rocm_enable_cppcheck( ...@@ -96,7 +89,10 @@ rocm_enable_cppcheck(
noExplicitConstructor noExplicitConstructor
passedByValue passedByValue
unusedStructMember unusedStructMember
definePrefix:*test/include/test.hpp
FORCE FORCE
RULE_FILE
${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules
SOURCES SOURCES
src/ src/
test/ test/
......
<?xml version="1.0"?>
<!-- Cppcheck custom rules: each <rule> runs its regex <pattern> over the
     tokenized source and reports the enclosed <message> when it matches. -->
<!-- Flags a dereference whose value is discarded: in "*p++;" the increment
     applies to p and the dereferenced value is thrown away, so the "*" is
     redundant. -->
<rule>
<pattern> [;{}] [*] \w+? (\+\+|\-\-) ; </pattern>
<message>
<id>UnusedDeref</id>
<severity>style</severity>
<summary>Redundant * found, "*p++" is the same as "*(p++)".</summary>
</message>
</rule>
<!-- Flags "if(strlen(s))", "if(!strlen(s))" and "if(strlen(s) > 0)":
     strlen walks the whole string just to test for emptiness. -->
<rule>
<pattern> if \( ([!] )*?(strlen) \( \w+? \) ([>] [0] )*?\) { </pattern>
<message>
<id>StrlenEmptyString</id>
<severity>performance</severity>
<summary>Using strlen() to check if a string is empty is not efficient.</summary>
</message>
</rule>
<!-- Removed: this rule was an exact duplicate (same pattern, id, severity
     and summary) of the UnusedDeref rule defined earlier in this file;
     keeping both would report every finding twice. -->
<!-- Macro names must be all uppercase: the pattern matches a define whose
     name contains at least one lowercase letter. -->
<rule>
<tokenlist>define</tokenlist>
<pattern>define [0-9A-Z_^a-z]*[a-z]</pattern>
<message>
<id>defineUpperCase</id>
<severity>style</severity>
<summary>Macros must be uppercase</summary>
</message>
</rule>
<!-- Flags macro definitions lacking the project prefix (per the summary,
     macros must start with MIGRAPH_). NOTE(review): the regex is an
     approximation of that check -- confirm its corner cases. -->
<rule>
<tokenlist>define</tokenlist>
<pattern>define (MIGRAP|[^H]{6})[^H][^_]</pattern>
<message>
<id>definePrefix</id>
<severity>style</severity>
<summary>Macros must be prefixed with MIGRAPH_</summary>
</message>
</rule>
<!-- Prefer the C++ algorithm (std::copy) over the C copy/concat routines. -->
<rule>
<pattern>(memcpy|strcpy|strncpy|strcat|strncat) \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::copy instead</summary>
</message>
</rule>
<!-- Prefer std::fill over memset for buffer initialization. -->
<rule>
<pattern>memset \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::fill instead</summary>
</message>
</rule>
<!-- Prefer the C++ algorithm: std::equal compares two ranges element-wise,
     which is what memcmp is used for. (The previous summary suggested
     std::equal_range, which is a binary search on a sorted range and is
     unrelated to memcmp.) -->
<rule>
<pattern>memcmp \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::equal instead</summary>
</message>
</rule>
<!-- Prefer std::find over memchr for searching a buffer. -->
<rule>
<pattern>memchr \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::find instead</summary>
</message>
</rule>
<!-- Direct release calls (fclose/free/hipFree) suggest manual resource
     management; per the summary, a managing (owning) pointer should free
     the resource instead. -->
<rule>
<pattern>(fclose|free|hipFree) \(</pattern>
<message>
<id>useManagePointer</id>
<severity>style</severity>
<summary>Use manage pointer for resource management</summary>
</message>
</rule>
<!-- Two adjacent logical-not tokens ("! !") cancel each other out. -->
<rule>
<tokenlist>normal</tokenlist>
<pattern>! !</pattern>
<message>
<id>doubleNegative</id>
<severity>style</severity>
<summary>Double negative is always positive</summary>
</message>
</rule>
sphinx sphinx==1.6.2
breathe==4.9.1 breathe==4.9.1
# git+https://github.com/arximboldi/breathe@fix-node-parent # git+https://github.com/arximboldi/breathe@fix-node-parent
...@@ -7,6 +7,7 @@ add_library(migraph ...@@ -7,6 +7,7 @@ add_library(migraph
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
env.cpp env.cpp
generate.cpp generate.cpp
instruction.cpp
program.cpp program.cpp
shape.cpp shape.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
......
...@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const ...@@ -10,7 +10,7 @@ void auto_contiguous::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
shape s = ins->result; shape s = ins->get_shape();
if(not s.standard()) if(not s.standard())
{ {
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins); auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
......
...@@ -17,16 +17,16 @@ void dead_code_elimination::apply(program& p) const ...@@ -17,16 +17,16 @@ void dead_code_elimination::apply(program& p) const
continue; continue;
const auto i = std::prev(ins); const auto i = std::prev(ins);
// Skip instruction with empty shape as output unless its a builtin // Skip instruction with empty shape as output unless its a builtin
if(i->result.elements() == 0 and not(i->op.name().front() == '@')) if(i->get_shape().elements() == 0 and not(i->name().front() == '@'))
continue; continue;
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
break; break;
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
assert(p.has_instruction(leaf)); assert(p.has_instruction(leaf));
if(leaf->output.empty()) if(leaf->outputs().empty())
{ {
auto args = leaf->arguments; auto args = leaf->inputs();
leaf->clear_arguments(); leaf->clear_arguments();
p.move_instruction(leaf, p.end()); p.move_instruction(leaf, p.end());
for(auto arg : args) for(auto arg : args)
......
...@@ -19,7 +19,7 @@ void eliminate_allocation::apply(program& p) const ...@@ -19,7 +19,7 @@ void eliminate_allocation::apply(program& p) const
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(ins->op.name() != allocation_op) if(ins->name() != allocation_op)
continue; continue;
allocs.emplace_back(ins, n); allocs.emplace_back(ins, n);
std::size_t size = ins->get_shape().bytes(); std::size_t size = ins->get_shape().bytes();
......
...@@ -27,19 +27,19 @@ void eliminate_contiguous::apply(program& p) const ...@@ -27,19 +27,19 @@ void eliminate_contiguous::apply(program& p) const
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->arguments; auto args = ins->inputs();
for(auto arg : ins->arguments) for(auto arg : ins->inputs())
{ {
// TODO: Pass in names for the operator in the constructor instead // TODO: Pass in names for the operator in the constructor instead
// of using ends_with // of using ends_with
if(ends_with(arg->op.name(), "contiguous")) if(ends_with(arg->name(), "contiguous"))
{ {
auto new_args = args; auto new_args = args;
auto prev = arg->arguments.front(); auto prev = arg->inputs().front();
replace(new_args, arg, prev); replace(new_args, arg, prev);
if(try_compute_shape(ins->op, new_args)) if(try_compute_shape(ins->get_operator(), new_args))
{ {
replace_argument(ins, arg, prev); instruction::replace_argument(ins, arg, prev);
} }
} }
} }
......
...@@ -10,30 +10,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -10,30 +10,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(ins->op.name() != "batch_norm_inference") if(ins->name() != "batch_norm_inference")
continue; continue;
if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) { if(not std::all_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
return arg->op.name() == "@literal"; return arg->name() == "@literal";
})) }))
continue; continue;
auto conv_ins = ins->arguments[0]; auto conv_ins = ins->inputs()[0];
if(conv_ins->op.name() != "convolution") if(conv_ins->name() != "convolution")
continue; continue;
if(conv_ins->arguments[1]->op.name() != "@literal") if(conv_ins->inputs()[1]->name() != "@literal")
continue; continue;
// Get scale, bias, mean, variance from instruction_ref // Get scale, bias, mean, variance from instruction_ref
const auto& gamma = ins->arguments[1]->get_literal(); const auto& gamma = ins->inputs()[1]->get_literal();
const auto& bias = ins->arguments[2]->get_literal(); const auto& bias = ins->inputs()[2]->get_literal();
const auto& mean = ins->arguments[3]->get_literal(); const auto& mean = ins->inputs()[3]->get_literal();
const auto& variance = ins->arguments[4]->get_literal(); const auto& variance = ins->inputs()[4]->get_literal();
// Get epsilon // Get epsilon
auto bn_op = any_cast<batch_norm_inference>(ins->op); auto bn_op = any_cast<batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon; auto epsilon = bn_op.epsilon;
// Get convolution weights // Get convolution weights
const auto& weights = conv_ins->arguments[1]->get_literal(); const auto& weights = conv_ins->inputs()[1]->get_literal();
// Get convolution op // Get convolution op
auto conv_op = conv_ins->op; auto conv_op = conv_ins->get_operator();
auto weights_lens = weights.get_shape().lens(); auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens(); auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()}; argument new_weights{weights.get_shape()};
...@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Replace convolution instruction with updated weights // Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()}); auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()}); auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights}); auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias); auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
p.replace_instruction(ins, add{}, {c, b}); p.replace_instruction(ins, add{}, {c, b});
} }
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <iso646.h> #include <iso646.h>
#endif #endif
#include <migraph/requires.hpp>
namespace migraph { namespace migraph {
template <class... Ts> template <class... Ts>
...@@ -15,7 +17,7 @@ using common_type = typename std::common_type<Ts...>::type; ...@@ -15,7 +17,7 @@ using common_type = typename std::common_type<Ts...>::type;
struct float_equal_fn struct float_equal_fn
{ {
template <class T> template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
static bool apply(T x, T y) static bool apply(T x, T y)
{ {
return std::isfinite(x) and std::isfinite(y) and return std::isfinite(x) and std::isfinite(y) and
...@@ -23,6 +25,12 @@ struct float_equal_fn ...@@ -23,6 +25,12 @@ struct float_equal_fn
std::nextafter(x, std::numeric_limits<T>::max()) >= y; std::nextafter(x, std::numeric_limits<T>::max()) >= y;
} }
template <class T, MIGRAPH_REQUIRES(not std::is_floating_point<T>{})>
static bool apply(T x, T y)
{
return x == y;
}
template <class T, class U> template <class T, class U>
bool operator()(T x, U y) const bool operator()(T x, U y) const
{ {
......
...@@ -12,7 +12,7 @@ constexpr T normalize(unsigned long z) ...@@ -12,7 +12,7 @@ constexpr T normalize(unsigned long z)
{ {
if(z == 0) if(z == 0)
return 0; return 0;
const auto max = 2048; const auto max = 32;
const double range = max / 2; // NOLINT const double range = max / 2; // NOLINT
double result = (z % max) / range; double result = (z % max) / range;
result -= 1; result -= 1;
......
...@@ -3,10 +3,8 @@ ...@@ -3,10 +3,8 @@
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp> #include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
...@@ -18,126 +16,60 @@ struct instruction ...@@ -18,126 +16,60 @@ struct instruction
{ {
instruction() {} instruction() {}
instruction(operation o, shape r, std::vector<instruction_ref> args) instruction(operation o, shape r, std::vector<instruction_ref> args);
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
instruction(literal l) : op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) {} instruction(literal l);
// internal void replace(const shape& r);
void replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
op = std::move(o);
replace(r);
replace(std::move(args));
}
void replace(const shape& r) void recompute_shape();
{
if(r != result)
{
result = r;
for(auto&& ins : output)
{
assert(ins->op.name().front() != '@');
ins->recompute_shape();
}
}
}
void recompute_shape() { replace(compute_shape(op, arguments)); } void clear_arguments();
// internal friend bool operator==(const instruction& i, instruction_ref ref);
void replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
// internal bool valid(instruction_ref start) const;
void replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
void clear_arguments() bool valid() const;
{
for(auto&& arg : arguments)
{
arg->remove_output(*this);
}
arguments.clear();
}
friend bool operator==(const instruction& i, instruction_ref ref) shape get_shape() const;
{ const literal& get_literal() const;
return std::addressof(i) == std::addressof(*ref);
}
bool valid(instruction_ref start) const const operation& get_operator() const;
{
return valid() && std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->output.begin(), i->output.end(), *this);
return self != i->output.end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
bool valid() const std::string name() const;
{
shape computed;
if(op.name() == "@literal")
{
computed = lit.get_shape();
}
else if(op.name() == "@param")
{
computed = result;
}
else
{
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
}
return result == computed &&
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
i->arguments.end();
});
}
shape get_shape() const { return result; } const std::vector<instruction_ref>& inputs() const;
const literal& get_literal() const
{
assert(op.name() == "@literal");
return lit;
}
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } const std::vector<instruction_ref>& outputs() const;
friend bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); } friend bool operator==(instruction_ref ref, const instruction& i);
friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); } friend bool operator!=(const instruction& i, instruction_ref ref);
void add_output(instruction_ref ins) friend bool operator!=(instruction_ref ref, const instruction& i);
{
if(std::find(output.begin(), output.end(), ins) == output.end()) void add_output(instruction_ref ins);
output.push_back(ins);
}
template <class T> template <class T>
void remove_output(const T& ins) void remove_output(const T& ins);
{
migraph::erase(output, ins); static void backreference(instruction_ref ref);
}
static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
private:
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
// internal
void replace(std::vector<instruction_ref> args);
// internal
void replace_argument(instruction_ref old, instruction_ref new_ins);
operation op; operation op;
shape result; shape result;
...@@ -146,29 +78,6 @@ struct instruction ...@@ -146,29 +78,6 @@ struct instruction
literal lit; literal lit;
}; };
inline void backreference(instruction_ref ref)
{
for(auto&& arg : ref->arguments)
arg->add_output(ref);
}
inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
}
// TODO: Move to a cpp file
// TODO: Use const ref for vector
inline shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->result; });
return op.compute_shape(shapes);
}
} // namespace migraph } // namespace migraph
namespace std { namespace std {
......
#ifndef GUARD_MIGRAPHLIB_ONNX_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#define GUARD_MIGRAPHLIB_ONNX_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp> #include <migraph/program.hpp>
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraph/auto_any_cast.hpp>
...@@ -27,13 +28,16 @@ struct operation ...@@ -27,13 +28,16 @@ struct operation
/// exception. /// exception.
shape compute_shape(const std::vector<shape>& input) const; shape compute_shape(const std::vector<shape>& input) const;
/** /**
* @brief This performs the operation's computation * @brief This performs the operation's computation.
*
* This method can be optional when the operation is only used as a placeholder to be lowered
* later on.
* *
* @param ctx This is the context created by the `target` during compilation. Implementations * @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class. * can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each * @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`. * `shape` of the `argument`.
* @param input This is the `argument` result from the previous instuction's computation. * @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
*/ */
...@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name()) ...@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream } // namespace operation_stream
template <class T>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPH_THROW("Not computable: " + name);
}
template <class T> template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return x.compute(auto_any_cast(ctx), output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
/* /*
......
...@@ -41,11 +41,6 @@ struct batch_norm_inference ...@@ -41,11 +41,6 @@ struct batch_norm_inference
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
return inputs.front(); return inputs.front();
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct convolution struct convolution
...@@ -115,11 +110,6 @@ struct convolution ...@@ -115,11 +110,6 @@ struct convolution
} }
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const convolution& op) friend std::ostream& operator<<(std::ostream& os, const convolution& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -131,6 +121,46 @@ struct convolution ...@@ -131,6 +121,46 @@ struct convolution
} }
}; };
struct im2col
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};

    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };

    std::string name() const { return "im2col"; }

    /// Computes the im2col output shape: one row per output pixel
    /// (output_height * output_width) and one column per kernel element
    /// (kernel_height * kernel_width * input_channels).
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Validate the argument count first so the unchecked indexing
        // below is safe; previously the check ran after inputs[0]/inputs[1]
        // were accessed, which is out-of-bounds on a wrong-arity call.
        check_shapes{inputs, *this}.has(2);
        auto input   = inputs[0];
        auto weights = inputs[1];

        auto batch_size     = input.lens()[0];
        auto input_channels = weights.lens()[1];
        auto kernel_height  = weights.lens()[2];
        auto kernel_width   = weights.lens()[3];

        if(batch_size != 1)
            MIGRAPH_THROW("im2col only support batch_size 1");

        // Standard convolution output-extent formula, clamped to at least 1:
        // (in - effective_kernel + 2*pad) / stride + 1, where the effective
        // kernel extent is 1 + dilation*(kernel - 1).
        auto output_height = std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
                    stride[0] +
                1));
        auto output_width = std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
                    stride[1] +
                1));
        auto channels_col = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
    }
};
struct pooling struct pooling
{ {
std::string mode = "average"; std::string mode = "average";
...@@ -166,11 +196,6 @@ struct pooling ...@@ -166,11 +196,6 @@ struct pooling
}}; }};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const pooling& op) friend std::ostream& operator<<(std::ostream& os, const pooling& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -191,11 +216,6 @@ struct activation ...@@ -191,11 +216,6 @@ struct activation
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.front(); return inputs.front();
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const activation& op) friend std::ostream& operator<<(std::ostream& os, const activation& op)
{ {
os << op.name() << ":" << op.mode; os << op.name() << ":" << op.mode;
...@@ -260,10 +280,6 @@ struct contiguous ...@@ -260,10 +280,6 @@ struct contiguous
} }
return {t, lens}; return {t, lens};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct reshape struct reshape
...@@ -304,12 +320,10 @@ struct reshape ...@@ -304,12 +320,10 @@ struct reshape
MIGRAPH_THROW("Wrong number of elements for reshape"); MIGRAPH_THROW("Wrong number of elements for reshape");
return s; return s;
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -337,11 +351,6 @@ struct gemm ...@@ -337,11 +351,6 @@ struct gemm
return {t, {a.lens()[0], b.lens()[1]}}; return {t, {a.lens()[0], b.lens()[1]}};
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
friend std::ostream& operator<<(std::ostream& os, const gemm& op) friend std::ostream& operator<<(std::ostream& os, const gemm& op)
{ {
os << op.name() << "["; os << op.name() << "[";
...@@ -357,10 +366,6 @@ struct unary ...@@ -357,10 +366,6 @@ struct unary
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct identity : unary struct identity : unary
...@@ -408,11 +413,6 @@ struct atan : unary ...@@ -408,11 +413,6 @@ struct atan : unary
std::string name() const { return "atan"; } std::string name() const { return "atan"; }
}; };
struct softmax : unary
{
std::string name() const { return "softmax"; }
};
struct tanh : unary struct tanh : unary
{ {
std::string name() const { return "tanh"; } std::string name() const { return "tanh"; }
...@@ -428,6 +428,16 @@ struct neg : unary ...@@ -428,6 +428,16 @@ struct neg : unary
std::string name() const { return "neg"; } std::string name() const { return "neg"; }
}; };
struct softmax
{
    std::string name() const { return "softmax"; }

    /// Softmax is defined here over exactly one 4-dimensional input;
    /// the output shape is identical to the input shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Enforce arity and rank before handing the shape back unchanged.
        check_shapes{inputs}.has(1).only_dims(4);
        return inputs.front();
    }
};
struct flatten struct flatten
{ {
uint64_t axis = 0; uint64_t axis = 0;
...@@ -508,10 +518,6 @@ struct binary ...@@ -508,10 +518,6 @@ struct binary
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
}
}; };
struct add : binary struct add : binary
......
...@@ -3,19 +3,10 @@ ...@@ -3,19 +3,10 @@
#include <algorithm> #include <algorithm>
#include <initializer_list> #include <initializer_list>
#include <migraph/rank.hpp>
namespace migraph { namespace migraph {
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
namespace detail { namespace detail {
template <class String, class T> template <class String, class T>
......
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP

namespace migraph {

/// Tag type for ordering overloads by priority: rank<N> publicly derives
/// from rank<N - 1>, so passing rank<N>{} prefers the overload taking the
/// highest rank and falls back down the chain via base-class conversion.
template <int Depth>
struct rank : rank<Depth - 1>
{
};

/// Base case that terminates the inheritance chain.
template <>
struct rank<0>
{
};

} // namespace migraph

#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment