Commit 98fd5e1d authored by Paul
Browse files

Merge branch 'develop' into eliminate-more-contiguous

parents f7a6d87f a1c7e7a5
...@@ -125,6 +125,8 @@ rocm_enable_cppcheck( ...@@ -125,6 +125,8 @@ rocm_enable_cppcheck(
functionConst:*program.* functionConst:*program.*
shadowFunction shadowFunction
shadowVar shadowVar
shadowVariable
unsafeClassDivZero
definePrefix:*test/include/test.hpp definePrefix:*test/include/test.hpp
FORCE FORCE
INCONCLUSIVE INCONCLUSIVE
......
...@@ -99,13 +99,14 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build -> ...@@ -99,13 +99,14 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build ->
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-5.0 -style=file {} | diff - {}\' | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-5.0 -style=file {} | diff - {}\'
''' '''
} }
}, clang: rocmnode('vega') { cmake_build -> }, clang_debug: rocmnode('vega') { cmake_build ->
stage('Clang Debug') { stage('Clang Debug') {
// TODO: Enanle integer // TODO: Enable integer
def sanitizers = "undefined" def sanitizers = "undefined"
def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("hcc", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("hcc", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_release: rocmnode('vega') { cmake_build ->
stage('Clang Release') { stage('Clang Release') {
cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release") cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release")
} }
......
pfultz2/rocm-recipes pfultz2/rocm-recipes
danmar/cppcheck@8aa68ee297c2d9ebadf5bcfd00c66ea8d9291e35 -DHAVE_RULES=1 danmar/cppcheck@ef714225bb31e9a76ac2484796763572386955ae -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@2490e42baa7d90458f0632fd9fbead2d395f41b9 ROCm-Developer-Tools/HIP@2490e42baa7d90458f0632fd9fbead2d395f41b9
python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
...@@ -28,10 +28,32 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -28,10 +28,32 @@ inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPHX_DRIVER_STATIC static #define MIGRAPHX_DRIVER_STATIC static
#endif #endif
template <class T>
using bare = std::remove_cv_t<std::remove_reference_t<T>>;
namespace detail {
template <class T>
auto is_container(int, T&& x) -> decltype(x.insert(x.end(), *x.begin()), std::true_type{});
template <class T>
std::false_type is_container(float, T&&);
} // namespace detail
template <class T>
struct is_container : decltype(detail::is_container(int(0), std::declval<T>()))
{
};
template <class T>
using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
template <class T> template <class T>
struct value_parser struct value_parser
{ {
template <MIGRAPHX_REQUIRES(not std::is_enum<T>{})> template <MIGRAPHX_REQUIRES(not std::is_enum<T>{} and not is_multi_value<T>{})>
static T apply(const std::string& x) static T apply(const std::string& x)
{ {
T result; T result;
...@@ -43,7 +65,7 @@ struct value_parser ...@@ -43,7 +65,7 @@ struct value_parser
return result; return result;
} }
template <MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <MIGRAPHX_REQUIRES(std::is_enum<T>{} and not is_multi_value<T>{})>
static T apply(const std::string& x) static T apply(const std::string& x)
{ {
std::ptrdiff_t i; std::ptrdiff_t i;
...@@ -54,6 +76,15 @@ struct value_parser ...@@ -54,6 +76,15 @@ struct value_parser
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse: " + x);
return static_cast<T>(i); return static_cast<T>(i);
} }
template <MIGRAPHX_REQUIRES(is_multi_value<T>{} and not std::is_enum<T>{})>
static T apply(const std::string& x)
{
T result;
using value_type = typename T::value_type;
result.insert(result.end(), value_parser<value_type>::apply(x));
return result;
}
}; };
struct argument_parser struct argument_parser
...@@ -69,6 +100,18 @@ struct argument_parser ...@@ -69,6 +100,18 @@ struct argument_parser
unsigned nargs = 1; unsigned nargs = 1;
}; };
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string_range(x);
}
template <class T, MIGRAPHX_REQUIRES(not is_multi_value<T>{})>
std::string as_string_value(const T& x)
{
return to_string(x);
}
template <class T, class... Fs> template <class T, class... Fs>
void operator()(T& x, const std::vector<std::string>& flags, Fs... fs) void operator()(T& x, const std::vector<std::string>& flags, Fs... fs)
{ {
...@@ -81,7 +124,7 @@ struct argument_parser ...@@ -81,7 +124,7 @@ struct argument_parser
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = migraphx::get_type_name<T>();
arg.default_value = to_string(x); arg.default_value = as_string_value(x);
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
} }
...@@ -127,7 +170,7 @@ struct argument_parser ...@@ -127,7 +170,7 @@ struct argument_parser
MIGRAPHX_DRIVER_STATIC auto append() MIGRAPHX_DRIVER_STATIC auto append()
{ {
return write_action([](auto&, auto& x, auto& params) { return write_action([](auto&, auto& x, auto& params) {
using type = typename decltype(params)::value_type; using type = typename bare<decltype(params)>::value_type;
std::transform(params.begin(), std::transform(params.begin(),
params.end(), params.end(),
std::inserter(x, x.end()), std::inserter(x, x.end()),
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
...@@ -80,11 +81,13 @@ struct compiler ...@@ -80,11 +81,13 @@ struct compiler
{ {
loader l; loader l;
bool gpu = true; bool gpu = true;
std::vector<std::string> fill1;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); l.parse(ap);
ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true)); ap(gpu, {"--gpu"}, ap.help("Compile on the gpu"), ap.set_value(true));
ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false)); ap(gpu, {"--cpu"}, ap.help("Compile on the cpu"), ap.set_value(false));
ap(fill1, {"--fill1"}, ap.help("Fill parameter with 1s"), ap.append());
} }
program compile() program compile()
...@@ -94,7 +97,14 @@ struct compiler ...@@ -94,7 +97,14 @@ struct compiler
return p; return p;
} }
auto params(const program& p) { return create_param_map(p, gpu); } auto params(const program& p)
{
program::parameter_map m;
for(auto&& s : fill1)
m[s] = fill_argument(p.get_parameter_shape(s), 1);
fill_param_map(m, p, gpu);
return m;
}
}; };
struct read : command<read> struct read : command<read>
...@@ -109,6 +119,19 @@ struct read : command<read> ...@@ -109,6 +119,19 @@ struct read : command<read>
} }
}; };
struct params : command<params>
{
loader l;
void parse(argument_parser& ap) { l.parse(ap); }
void run()
{
auto p = l.load();
for(auto&& param : p.get_parameter_shapes())
std::cout << param.first << ": " << param.second << std::endl;
}
};
struct verify : command<verify> struct verify : command<verify>
{ {
loader l; loader l;
......
...@@ -11,6 +11,23 @@ namespace migraphx { ...@@ -11,6 +11,23 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu)
{
for(auto&& x : p.get_parameter_shapes())
{
argument& arg = m[x.first];
if(arg.empty())
arg = generate_argument(x.second);
#ifdef HAVE_GPU
if(gpu)
arg = gpu::to_gpu(arg);
#else
(void)gpu;
#endif
}
return m;
}
program::parameter_map create_param_map(const program& p, bool gpu) program::parameter_map create_param_map(const program& p, bool gpu)
{ {
program::parameter_map m; program::parameter_map m;
......
...@@ -7,6 +7,7 @@ namespace migraphx { ...@@ -7,6 +7,7 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program::parameter_map fill_param_map(program::parameter_map& m, const program& p, bool gpu);
program::parameter_map create_param_map(const program& p, bool gpu = true); program::parameter_map create_param_map(const program& p, bool gpu = true);
void compile_program(program& p, bool gpu = true); void compile_program(program& p, bool gpu = true);
......
...@@ -3,6 +3,17 @@ ...@@ -3,6 +3,17 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
argument fill_argument(shape s, unsigned long value)
{
argument result;
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
auto v = fill_tensor_data<type>(s, value);
result = {s, [v]() mutable { return reinterpret_cast<char*>(v.data()); }};
});
return result;
}
argument generate_argument(shape s, unsigned long seed) argument generate_argument(shape s, unsigned long seed)
{ {
argument result; argument result;
......
...@@ -36,7 +36,7 @@ struct argument : raw_data<argument> ...@@ -36,7 +36,7 @@ struct argument : raw_data<argument>
} }
/// Provides a raw pointer to the data /// Provides a raw pointer to the data
std::function<char*()> data; std::function<char*()> data = nullptr;
/// Whether data is available /// Whether data is available
bool empty() const { return not data; } bool empty() const { return not data; }
......
...@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed ...@@ -87,6 +87,16 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
return result; return result;
} }
template <class T>
std::vector<T> fill_tensor_data(const migraphx::shape& s, unsigned long value = 0)
{
std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), [=] { return value; });
return result;
}
argument fill_argument(shape s, unsigned long value = 0);
argument generate_argument(shape s, unsigned long seed = 0); argument generate_argument(shape s, unsigned long seed = 0);
literal generate_literal(shape s, unsigned long seed = 0); literal generate_literal(shape s, unsigned long seed = 0);
......
...@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name) ...@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) { [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result != ctx.not_found()) if(result != ctx.not_found())
ctx.instructions.emplace(name, ins); ctx.instructions[name] = ins;
return result; return result;
}); });
} }
...@@ -374,14 +374,14 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ...@@ -374,14 +374,14 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
} }
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins->outputs().front(); return ins->outputs().front();
return ctx.not_found(); return ctx.not_found();
} }
MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins; return ins;
...@@ -392,7 +392,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins) ...@@ -392,7 +392,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
inline auto used_once_recursive(std::size_t depth) inline auto used_once_recursive(std::size_t depth)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once // Used once
if(start->outputs().size() == 1) if(start->outputs().size() == 1)
return start; return start;
...@@ -427,7 +427,7 @@ inline auto used_once_recursive(std::size_t depth) ...@@ -427,7 +427,7 @@ inline auto used_once_recursive(std::size_t depth)
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); } MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().empty() and ins != std::prev(ctx.not_found())) if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins; return ins;
...@@ -482,7 +482,7 @@ inline auto nargs(std::size_t n) ...@@ -482,7 +482,7 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i) inline auto arg(std::size_t i)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
if(i < ins->inputs().size()) if(i < ins->inputs().size())
return ins->inputs()[i]; return ins->inputs()[i];
return ctx.not_found(); return ctx.not_found();
......
...@@ -30,23 +30,29 @@ struct binary : op_name<Derived> ...@@ -30,23 +30,29 @@ struct binary : op_name<Derived>
argument result{output_shape}; argument result{output_shape};
auto s1 = args[0].get_shape(); auto s1 = args[0].get_shape();
auto s2 = args[1].get_shape(); auto s2 = args[1].get_shape();
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { if(s1 == s2 and s1.packed())
if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed()) {
{ shape std_shape{s1.type(), s1.lens()};
argument std_result{std_shape, result.data()};
argument std_arg0{std_shape, args[0].data()};
argument std_arg1{std_shape, args[1].data()};
visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(), std::transform(input1.begin(),
input1.end(), input1.end(),
input2.begin(), input2.begin(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
} });
else }
{ else
{
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()( output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end())); input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
}); });
} });
}); }
return result; return result;
} }
......
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cstddef>
#include <functional>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Pass-through operator that hands the instruction's arguments to a
/// user-supplied callback and returns the first argument unchanged.
/// Used by capture_arguments() to observe intermediate values
/// (presumably for int8 quantization calibration — see quantize.hpp).
struct capture
{
    // Index identifying the captured instruction. Default-initialized so a
    // default-constructed capture never exposes an indeterminate value.
    std::size_t ins_index = 0;
    // Callback invoked from compute(). Not reflected: std::function cannot
    // be serialized, so only ins_index participates in reflection.
    std::function<void(std::size_t ins_index, std::vector<argument>)> f{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.ins_index, "ins_index"));
    }

    std::string name() const { return "capture"; }

    // Shape-preserving: the output shape is exactly the first input's shape.
    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }

    /// Invokes the callback with (ins_index, args) and forwards args.front().
    /// @throws if no callback has been installed.
    argument compute(const shape&, std::vector<argument> args) const
    {
        if(not f)
        {
            MIGRAPHX_THROW("CAPTURE: callback function is not callable!");
        }
        f(ins_index, args);
        return args.front();
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP

#include <algorithm>
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cstddef>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// 2-d convolution over int8 data and weights; the output type is int32.
struct quant_convolution
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
    padding_mode_t padding_mode         = default_;
    int group                           = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
    }

    std::string name() const { return "quant_convolution"; }

    /// Computes the NCHW output shape from (data, weights).
    /// @throws when the inputs are not two 4-d shapes of type int8.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
        const shape& input   = inputs.at(0);
        const shape& weights = inputs.at(1);
        // all inputs must be int8_type; the output accumulates into int32_type
        if(input.type() != shape::int8_type)
        {
            MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
        }
        // Standard convolution output extent for spatial dimension i
        // (i = 0 -> H, i = 1 -> W), clamped to at least 1. Factored out so
        // the formula is written once for both spatial dims.
        auto output_dim = [&](std::size_t i) {
            return std::size_t(std::max<std::ptrdiff_t>(
                1,
                (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
                 2 * padding[i]) /
                        stride[i] +
                    1));
        };
        return {shape::int32_type,
                {input.lens()[0], weights.lens()[0], output_dim(0), output_dim(1)}};
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP

#include <algorithm>
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cstddef>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Quantized matrix multiply: int8 operands A and B (optionally a third
/// int32 operand C), producing an int32 result. alpha/beta are GEMM-style
/// scale attributes consumed by the backend kernels (presumably
/// C = alpha*A*B + beta*C — confirm against the backend implementation).
struct quant_dot
{
    int32_t alpha = 1;
    int32_t beta  = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta"));
    }

    std::string name() const { return "quant_dot"; }

    /// Validates operand shapes/types and returns the int32 output shape.
    /// Checks run in a fixed order so error precedence is stable.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);

        // Only int8 operands are supported.
        if(a.type() != shape::int8_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
        }

        const bool all_matrices = std::all_of(
            inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; });
        if(not all_matrices)
        {
            MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
        }

        // Only handle the case where the batch (leading) dims of A and B match.
        const bool same_batch = std::equal(
            a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend());
        if(not same_batch)
        {
            MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        const std::size_t penult = a.lens().size() - 2; // row dim of the trailing matrix
        const std::size_t last   = a.lens().size() - 1; // column dim of the trailing matrix
        if(a.lens()[last] != b.lens()[penult])
        {
            MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        // The int8 path requires the reduction size k to be a multiple of 4.
        if((a.lens()[last] % 4) != 0)
        {
            MIGRAPHX_THROW("QUANT_DOT: size of A {" + to_string_range(a.lens()) + "} and B {" +
                           to_string_range(b.lens()) + "} must be multiple of 4 for int8 type");
        }

        auto out_lens  = a.lens();
        out_lens[last] = b.lens()[last];
        if(inputs.size() == 3 and out_lens != inputs.at(2).lens())
        {
            MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
                           to_string_range(inputs.at(2).lens()) +
                           "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
        }
        if(inputs.size() == 3 and inputs.at(2).type() != shape::int32_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
        }
        return {shape::int32_type, out_lens};
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ROUND_HPP
#define MIGRAPHX_GUARD_OPERATORS_ROUND_HPP

#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
// Include <cmath> directly for std::round instead of relying on a
// transitive include from unary.hpp.
#include <cmath>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/// Element-wise rounding to the nearest integer value; halfway cases are
/// rounded away from zero (std::round semantics).
struct round : unary<round>
{
    auto apply() const
    {
        return [](auto x) { return std::round(x); };
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
...@@ -27,26 +27,34 @@ struct unary : op_name<Derived> ...@@ -27,26 +27,34 @@ struct unary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
result.visit([&](auto output) { auto in_shape = args[0].get_shape();
args[0].visit([&](auto input) { if(in_shape.packed())
if(input.get_shape().packed()) {
{ shape std_in_shape{in_shape.type(), in_shape.lens()};
shape std_out_shape{output_shape.type(), output_shape.lens()};
argument arg_in{std_in_shape, args[0].data()};
argument arg_out{std_out_shape, result.data()};
arg_out.visit([&](auto output) {
arg_in.visit([&](auto input) {
std::transform(input.begin(), std::transform(input.begin(),
input.end(), input.end(),
output.begin(), output.begin(),
static_cast<const Derived&>(*this).apply()); static_cast<const Derived&>(*this).apply());
return result;
}
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
}); });
return result;
}); });
}); }
else
{
result.visit([&](auto output) {
args[0].visit([&](auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
});
});
}
return result; return result;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp> #include <migraphx/op/concat.hpp>
...@@ -45,6 +46,8 @@ ...@@ -45,6 +46,8 @@
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp> #include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp> #include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
...@@ -53,6 +56,7 @@ ...@@ -53,6 +56,7 @@
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp> #include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_output.hpp> #include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
......
...@@ -15,6 +15,15 @@ struct program; ...@@ -15,6 +15,15 @@ struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names); void quantize(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog); void quantize(program& prog);
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -23,9 +23,10 @@ using bool_c = std::integral_constant<bool, B>; ...@@ -23,9 +23,10 @@ using bool_c = std::integral_constant<bool, B>;
#ifdef CPPCHECK #ifdef CPPCHECK
#define MIGRAPHX_REQUIRES(...) class = void #define MIGRAPHX_REQUIRES(...) class = void
#else #else
#define MIGRAPHX_REQUIRES(...) \ #define MIGRAPHX_REQUIRES(...) \
bool MIGRAPHX_REQUIRES_VAR() = true, \ long MIGRAPHX_REQUIRES_VAR() = __LINE__, \
typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() && (migraphx::and_<__VA_ARGS__>{})), \ typename std::enable_if<(MIGRAPHX_REQUIRES_VAR() == __LINE__ && \
(migraphx::and_<__VA_ARGS__>{})), \
int>::type = 0 int>::type = 0
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment