"doc/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "bd33a59ecffe045b55663f9b1cd9152d1be46240"
Commit d9170e2d authored by Paul's avatar Paul
Browse files

Merge branch 'master' into im2col_cpu

parents 674ea92d 9fee0fe4
CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: modernize-loop-convert.MinConfidence
value: risky
- key: modernize-loop-convert.NamingStyle
......
...@@ -22,16 +22,6 @@ add_compile_options(-std=c++14)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings)
-# Override clang-tidy to not find the version from hcc
-# find_program(CLANG_TIDY_EXE
-# NAMES
-# clang-tidy
-# clang-tidy-5.0
-# clang-tidy-6.0
-# clang-tidy-7.0
-# PATHS
-# /usr/local/opt/llvm/bin
-# )
include(ROCMClangTidy)
rocm_enable_clang_tidy(
CHECKS
...@@ -87,8 +77,11 @@ rocm_enable_clang_tidy(
)
include(ROCMCppCheck)
rocm_enable_cppcheck(
CHECKS
-all
+warning
+style
+performance
+portability
SUPPRESS
ConfigurationNotChecked
unmatchedSuppression
...@@ -96,7 +89,10 @@ rocm_enable_cppcheck(
noExplicitConstructor
passedByValue
unusedStructMember
+definePrefix:*test/include/test.hpp
FORCE
+RULE_FILE
+${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules
SOURCES
src/
test/
......
<?xml version="1.0"?>
<rule>
<pattern> [;{}] [*] \w+? (\+\+|\-\-) ; </pattern>
<message>
<id>UnusedDeref</id>
<severity>style</severity>
<summary>Redundant * found, "*p++" is the same as "*(p++)".</summary>
</message>
</rule>
<rule>
<pattern> if \( ([!] )*?(strlen) \( \w+? \) ([>] [0] )*?\) { </pattern>
<message>
<id>StrlenEmptyString</id>
<severity>performance</severity>
<summary>Using strlen() to check if a string is empty is not efficient.</summary>
</message>
</rule>
<rule>
<pattern> [;{}] [*] \w+? (\+\+|\-\-) ; </pattern>
<message>
<id>UnusedDeref</id>
<severity>style</severity>
<summary>Redundant * found, "*p++" is the same as "*(p++)".</summary>
</message>
</rule>
<rule>
<tokenlist>define</tokenlist>
<pattern>define [0-9A-Z_^a-z]*[a-z]</pattern>
<message>
<id>defineUpperCase</id>
<severity>style</severity>
<summary>Macros must be uppercase</summary>
</message>
</rule>
<rule>
<tokenlist>define</tokenlist>
<pattern>define (MIGRAP|[^H]{6})[^H][^_]</pattern>
<message>
<id>definePrefix</id>
<severity>style</severity>
<summary>Macros must be prefixed with MIGRAPH_</summary>
</message>
</rule>
<rule>
<pattern>(memcpy|strcpy|strncpy|strcat|strncat) \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::copy instead</summary>
</message>
</rule>
<rule>
<pattern>memset \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::fill instead</summary>
</message>
</rule>
<rule>
<pattern>memcmp \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::equal_range instead</summary>
</message>
</rule>
<rule>
<pattern>memchr \(</pattern>
<message>
<id>useStlAlgorithms</id>
<severity>style</severity>
<summary>Use std::find instead</summary>
</message>
</rule>
<rule>
<pattern>(fclose|free|hipFree) \(</pattern>
<message>
<id>useManagePointer</id>
<severity>style</severity>
<summary>Use manage pointer for resource management</summary>
</message>
</rule>
<rule>
<tokenlist>normal</tokenlist>
<pattern>! !</pattern>
<message>
<id>doubleNegative</id>
<severity>style</severity>
<summary>Double negative is always positive</summary>
</message>
</rule>
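For reference, a small hand-written C++ snippet of the kind of code the custom rules above are written to flag; it is illustrative only, not part of this change, and all names are made up:

#include <string.h>
#include <vector>

#define buffer_size 16 // defineUpperCase: macro names must be uppercase
#define MAX_ELEMS 32   // definePrefix: macro is not prefixed with MIGRAPH_

void examples(char* dst, const char* src, std::vector<char>& v)
{
    if(strlen(src) > 0) // StrlenEmptyString: prefer src[0] != '\0'
    {
        strcpy(dst, src);              // useStlAlgorithms: prefer std::copy
        memset(v.data(), 0, v.size()); // useStlAlgorithms: prefer std::fill
    }
}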
...@@ -2,7 +2,9 @@
add_library(migraph
auto_contiguous.cpp
dead_code_elimination.cpp
+eliminate_allocation.cpp
eliminate_contiguous.cpp
+fwd_conv_batchnorm_rewrite.cpp
env.cpp
generate.cpp
program.cpp
......
-#include <migraph/gpu/eliminate_allocation.hpp>
+#include <migraph/eliminate_allocation.hpp>
-#include <migraph/gpu/hip.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
namespace migraph {
-namespace gpu {
void eliminate_allocation::apply(program& p) const
{
+assert(alignment > 0);
std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p))
{
-if(ins->op.name() != "hip::allocate")
+if(ins->op.name() != allocation_op)
continue;
allocs.emplace_back(ins, n);
std::size_t size = ins->get_shape().bytes();
-n += size + (size % 4);
+std::size_t padding = (alignment - (size % alignment)) % alignment;
+n += size + padding;
}
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
for(auto&& pp : allocs)
...@@ -28,8 +27,7 @@ void eliminate_allocation::apply(program& p) const
auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
-p.replace_instruction(ins, hip_load{s, offset}, mem);
+p.replace_instruction(ins, load{s, offset}, mem);
}
}
-} // namespace gpu
} // namespace migraph
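A quick self-contained check of the new padding arithmetic (values are illustrative; 32 is the default alignment declared in the header further down):

#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t alignment = 32;
    auto pad = [&](std::size_t size) { return (alignment - (size % alignment)) % alignment; };
    assert(pad(10) == 22); // 10 + 22 = 32
    assert(pad(64) == 0);  // already a multiple of 32
    assert(pad(65) == 31); // 65 + 31 = 96
}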
#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {
void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->op.name() != "batch_norm_inference")
continue;
if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
return arg->op.name() == "@literal";
}))
continue;
auto conv_ins = ins->arguments[0];
if(conv_ins->op.name() != "convolution")
continue;
if(conv_ins->arguments[1]->op.name() != "@literal")
continue;
// Get scale, bias, mean, variance from instruction_ref
const auto& gamma = ins->arguments[1]->get_literal();
const auto& bias = ins->arguments[2]->get_literal();
const auto& mean = ins->arguments[3]->get_literal();
const auto& variance = ins->arguments[4]->get_literal();
// Get epsilon
auto bn_op = any_cast<batch_norm_inference>(ins->op);
auto epsilon = bn_op.epsilon;
// Get convolution weights
const auto& weights = conv_ins->arguments[1]->get_literal();
// Get convolution op
auto conv_op = conv_ins->op;
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2,
auto gamma2,
auto bias2,
auto mean2,
auto variance2,
auto new_weights2,
auto new_bias2) {
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
});
dfor(new_bias.get_shape().elements())([&](std::size_t c) {
new_bias2(c) = bias2(c) - (mean2(c) / std::sqrt(variance2(c) + epsilon));
});
});
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
p.replace_instruction(ins, add{}, {c, b});
}
}
} // namespace migraph
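For reference, the usual identities for folding an inference-time batch norm into the preceding convolution, written in terms of the gamma, bias, mean, variance and epsilon literals read above, are:

W'(k,c,h,w) = gamma(k) / sqrt(variance(k) + epsilon) * W(k,c,h,w)
bias'(k)    = bias(k) - gamma(k) * mean(k) / sqrt(variance(k) + epsilon)

so that convolving with W' and broadcast-adding bias' reproduces batch_norm_inference applied to the original convolution output.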
...@@ -7,14 +7,13 @@
namespace migraph {
struct program;
-namespace gpu {
struct eliminate_allocation
{
+std::string allocation_op{};
+std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; }
void apply(program& p) const;
};
-} // namespace gpu
} // namespace migraph
#endif
...@@ -8,6 +8,8 @@
#include <iso646.h>
#endif
+#include <migraph/requires.hpp>
namespace migraph {
template <class... Ts>
...@@ -15,7 +17,7 @@ using common_type = typename std::common_type<Ts...>::type;
struct float_equal_fn
{
-template <class T>
+template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})>
static bool apply(T x, T y)
{
return std::isfinite(x) and std::isfinite(y) and
...@@ -23,6 +25,12 @@ struct float_equal_fn
std::nextafter(x, std::numeric_limits<T>::max()) >= y;
}
+template <class T, MIGRAPH_REQUIRES(not std::is_floating_point<T>{})>
+static bool apply(T x, T y)
+{
+return x == y;
+}
template <class T, class U>
bool operator()(T x, U y) const
{
......
...@@ -5,6 +5,14 @@
namespace migraph {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
namespace detail {
template <class R, class F>
...@@ -19,8 +27,48 @@ struct fix_f
}
};
template <std::size_t...>
struct seq
{
using type = seq;
};
template <class, class>
struct merge_seq;
template <std::size_t... Xs, std::size_t... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
template <std::size_t N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
template <class F, std::size_t... Ns>
constexpr void repeat_c_impl(F f, seq<Ns...>)
{
swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
}
} // namespace detail
template <std::size_t N, class F>
constexpr void repeat_c(F f)
{
detail::repeat_c_impl(f, detail::gens<N>{});
}
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
...@@ -35,7 +83,7 @@ auto fix(F f)
}
template <class... Ts>
-auto make_sequence(Ts... xs)
+auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
......
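A minimal usage sketch of the new repeat_c helper (assuming the header installs as <migraph/functional.hpp>, as other includes in this change suggest); each invocation receives its index as a std::integral_constant, so the index can also be used where a compile-time constant is required:

#include <migraph/functional.hpp>
#include <array>
#include <iostream>

int main()
{
    std::array<std::size_t, 4> a{};
    // The lambda is called with integral_constant<size_t, 0> ... <size_t, 3>.
    migraph::repeat_c<4>([&](auto i) { a[i] = 2 * i; });
    for(auto x : a)
        std::cout << x << " "; // prints: 0 2 4 6
    std::cout << std::endl;
}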
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct fwd_conv_batchnorm_rewrite
{
std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
...@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
{
if(z == 0)
return 0;
-return (2.0 / z) - 1.0;
+const auto max = 2048;
+const double range = max / 2; // NOLINT
+double result = (z % max) / range;
+result -= 1;
+return result;
}
template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})>
...@@ -54,11 +58,29 @@ struct xorshf96_generator
}
};
template <class T>
struct xorshift_generator
{
unsigned long x;
xorshift_generator(unsigned long seed = 0) : x(521288629ULL ^ seed) {}
constexpr T operator()() noexcept
{
x ^= x >> 12U;
x ^= x << 25U;
x ^= x >> 27U;
return normalize<T>(x * 0x2545F4914F6CDD1D);
}
};
template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
{
std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
+// std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
+// std::generate(result.begin(), result.end(), []{ return 1; });
return result;
}
......
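A small sketch of how the seeded generator is meant to be used (assuming the usual shape(type, lengths) constructor); threading a seed through generate_tensor_data makes the data reproducible per parameter:

#include <migraph/generate.hpp>
#include <cassert>

int main()
{
    migraph::shape s{migraph::shape::float_type, {2, 3}};
    auto a = migraph::generate_tensor_data<float>(s, 1);
    auto b = migraph::generate_tensor_data<float>(s, 1); // same seed -> same data
    auto c = migraph::generate_tensor_data<float>(s, 2); // different seed -> almost surely different data
    assert(a == b);
    assert(a != c);
}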
...@@ -115,6 +115,11 @@ struct instruction
}
shape get_shape() const { return result; }
const literal& get_literal() const
{
assert(op.name() == "@literal");
return lit;
}
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
......
-#ifndef GUARD_MIGRAPHLIB_ONNX_HPP
+#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
-#define GUARD_MIGRAPHLIB_ONNX_HPP
+#define MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraph/program.hpp>
......
...@@ -579,6 +579,22 @@ struct div : binary
std::string name() const { return "div"; }
};
struct load
{
shape s;
std::size_t offset = 0;
std::string name() const { return "load"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(1);
return s;
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
return {s, args[0].data() + offset};
}
};
struct outline
{
shape s;
......
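Informally, the new load op is what lets eliminate_allocation replace many per-instruction allocations with offset views into one pooled buffer; a rough before/after sketch of the IR, with shapes, sizes and offsets made up for illustration:

before:  a1 = allocate[shape=float{64}]
         a2 = allocate[shape=float{32}]

after:   memory = @param "memory", int8{384}
         a1 = load[shape=float{64}, offset=0](memory)     (64 floats = 256 bytes)
         a2 = load[shape=float{32}, offset=256](memory)   (next 32-byte-aligned offset)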
...@@ -2,17 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
+#include <migraph/functional.hpp>
namespace migraph {
-struct swallow
-{
-template <class... Ts>
-swallow(Ts&&...)
-{
-}
-};
struct tracer
{
tracer() {}
......
...@@ -140,7 +140,7 @@ std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff)
{
return mismatch_idx(r1, r2, [&](auto x, auto y) {
auto d = abs_diff(x, y);
-return !(d > diff && d < diff);
+return float_equal(d, diff);
});
}
...@@ -162,10 +162,12 @@ double rms_range(R1&& r1, R2&& r2)
}
template <class R1, class R2>
-bool verify_range(R1&& r1, R2&& r2, double tolerance = 80)
+bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = nullptr)
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
+if(out_error != nullptr)
+*out_error = error;
return error <= threshold;
}
} // namespace migraph
......
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraph/verify.hpp>
#include <migraph/argument.hpp>
namespace migraph {
inline void verify_args(const std::string& name,
const argument& cpu_arg,
const argument& gpu_arg,
double tolerance = 80)
{
visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
double error;
if(not verify_range(cpu, gpu, tolerance, &error))
{
// TODO: Check for nans
std::cout << "FAILED: " << name << std::endl;
std::cout << "error: " << error << std::endl;
if(cpu.size() < 32)
std::cout << "cpu:" << cpu << std::endl;
if(gpu.size() < 32)
std::cout << "gpu:" << gpu << std::endl;
if(range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
auto idx = mismatch_idx(cpu, gpu, float_equal);
if(idx < range_distance(cpu))
{
std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
<< std::endl;
}
auto cpu_nan_idx = find_idx(cpu, not_finite);
if(cpu_nan_idx >= 0)
std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
<< cpu[cpu_nan_idx] << std::endl;
auto gpu_nan_idx = find_idx(gpu, not_finite);
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
std::cout << std::endl;
}
});
}
} // namespace migraph
#endif
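A minimal usage sketch of verify_args (assuming generate_argument is available from <migraph/generate.hpp>, as the driver below suggests); identical inputs pass silently, while mismatches print the diagnostics above:

#include <migraph/verify_args.hpp>
#include <migraph/generate.hpp>

int main()
{
    migraph::shape s{migraph::shape::float_type, {8}};
    auto cpu = migraph::generate_argument(s, 1);
    auto gpu = migraph::generate_argument(s, 1); // same seed, so the check passes
    migraph::verify_args("example", cpu, gpu);   // default tolerance of 80
}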
...@@ -5,61 +5,110 @@
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
-#include <migraph/verify.hpp>
+#include <migraph/verify_args.hpp>
+#include <migraph/instruction.hpp>
-migraph::argument run_cpu(const std::string& file)
+template <class T>
+auto get_hash(const T& x)
{
-auto p = migraph::parse_onnx(file);
+return std::hash<T>{}(x);
+}
+template <class F>
+migraph::argument run_cpu(F f)
+{
+auto p = f();
p.compile(migraph::cpu::cpu_target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
-m[x.first] = migraph::generate_argument(x.second);
+m[x.first] = migraph::generate_argument(x.second, get_hash(x.first));
}
auto out = p.eval(m);
std::cout << p << std::endl;
return out;
}
-migraph::argument run_gpu(const std::string& file)
+template <class F>
+migraph::argument run_gpu(F f)
{
-auto p = migraph::parse_onnx(file);
+auto p = f();
p.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
-m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
+m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first)));
}
auto out = migraph::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl;
return migraph::gpu::from_gpu(out);
}
template <class F>
void verify_program(const std::string& name, F f, double tolerance = 100)
{
auto x = run_cpu(f);
auto y = run_gpu(f);
migraph::verify_args(name, x, y, tolerance);
}
void verify_instructions(const migraph::program& prog, double tolerance = 80)
{
for(auto&& ins : prog)
{
if(ins.op.name().front() == '@')
continue;
if(ins.op.name() == "broadcast")
continue;
if(ins.op.name() == "transpose")
continue;
if(ins.op.name() == "reshape")
continue;
auto create_program = [&] {
migraph::program p;
std::vector<migraph::instruction_ref> inputs;
for(auto&& arg : ins.arguments)
{
if(arg->op.name() == "@literal")
inputs.push_back(p.add_literal(arg->lit));
else
inputs.push_back(
p.add_parameter(std::to_string(inputs.size()), arg->get_shape()));
}
p.add_instruction(ins.op, inputs);
return p;
};
try
{
std::cout << "Verify: " << ins.op.name() << std::endl;
std::cout << create_program() << std::endl;
verify_program(ins.op.name(), create_program, tolerance);
}
catch(...)
{
std::cout << "Instruction " << ins.op.name() << " threw an exception." << std::endl;
throw;
}
}
}
int main(int argc, char const* argv[])
{
-if(argc > 1)
+std::vector<std::string> args(argv + 1, argv + argc);
+if(not args.empty())
{
-std::string file = argv[1];
+std::string file = args.front();
auto p = migraph::parse_onnx(file);
std::cout << p << std::endl;
-auto x = run_cpu(file);
-auto y = run_gpu(file);
-visit_all(x, y)([](auto cpu, auto gpu) {
-if(migraph::verify_range(cpu, gpu, 100))
-{
-std::cout << "Passed" << std::endl;
-}
-else
-{
-std::cout << "Not equal" << std::endl;
-std::cout << "cpu:" << std::endl;
-std::cout << cpu << std::endl;
-std::cout << "gpu:" << std::endl;
-std::cout << gpu << std::endl;
-}
-});
+if(std::any_of(args.begin(), args.end(), [](const auto& s) { return s == "-i"; }))
+{
+verify_instructions(p);
+}
+else
+{
+verify_program(file, [&] { return migraph::parse_onnx(file); });
+}
}
}