"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "3d62579a6a237223b4b8158b4b03c9ae10b6abf8"
Unverified Commit 460f33b8, authored by Paul Fultz II, committed by GitHub

Merge branch 'develop' into eliminate-more-contiguous

parents f7a6d87f 3ab91a79
@@ -125,6 +125,8 @@ rocm_enable_cppcheck(
         functionConst:*program.*
         shadowFunction
         shadowVar
+        shadowVariable
+        unsafeClassDivZero
         definePrefix:*test/include/test.hpp
     FORCE
     INCONCLUSIVE
...
@@ -99,13 +99,14 @@ rocmtest tidy: rocmnode('rocmtest') { cmake_build ->
             | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-5.0 -style=file {} | diff - {}\'
         '''
     }
-}, clang: rocmnode('vega') { cmake_build ->
+}, clang_debug: rocmnode('vega') { cmake_build ->
     stage('Clang Debug') {
-        // TODO: Enanle integer
+        // TODO: Enable integer
         def sanitizers = "undefined"
         def debug_flags = "-g -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
         cmake_build("hcc", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
     }
+}, clang_release: rocmnode('vega') { cmake_build ->
     stage('Clang Release') {
         cmake_build("hcc", "-DCMAKE_BUILD_TYPE=release")
     }
...
 pfultz2/rocm-recipes
-danmar/cppcheck@8aa68ee297c2d9ebadf5bcfd00c66ea8d9291e35 -DHAVE_RULES=1
+danmar/cppcheck@ef714225bb31e9a76ac2484796763572386955ae -DHAVE_RULES=1
 ROCm-Developer-Tools/HIP@2490e42baa7d90458f0632fd9fbead2d395f41b9
 python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
 -f requirements.txt
@@ -36,7 +36,7 @@ struct argument : raw_data<argument>
     }

     /// Provides a raw pointer to the data
-    std::function<char*()> data;
+    std::function<char*()> data = nullptr;

     /// Whether data is available
     bool empty() const { return not data; }
...
@@ -374,14 +374,14 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
         ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
 }

-MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
+MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
         return ins->outputs().front();
     return ctx.not_found();
 }

-MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
+MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
         return ins;
@@ -482,7 +482,7 @@ inline auto nargs(std::size_t n)

 inline auto arg(std::size_t i)
 {
-    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
+    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
         if(i < ins->inputs().size())
             return ins->inputs()[i];
         return ctx.not_found();
...
@@ -30,23 +30,29 @@ struct binary : op_name<Derived>
         argument result{output_shape};
         auto s1 = args[0].get_shape();
         auto s2 = args[1].get_shape();
-        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
-            {
+        if(s1 == s2 and s1.packed())
+        {
+            shape std_shape{s1.type(), s1.lens()};
+            argument std_result{std_shape, result.data()};
+            argument std_arg0{std_shape, args[0].data()};
+            argument std_arg1{std_shape, args[1].data()};
+            visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
                 std::transform(input1.begin(),
                                input1.end(),
                                input2.begin(),
                                output.begin(),
                                static_cast<const Derived&>(*this).apply());
-            }
-            else
-            {
+            });
+        }
+        else
+        {
+            visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
                 shape_for_each(output.get_shape(), [&](const auto& idx) {
                     output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
                         input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
                 });
-            }
-        });
+            });
+        }
         return result;
     }
...
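The fast path introduced above relies on one property of packed shapes: their elements occupy a single contiguous run of memory, so an elementwise operation can walk flat ranges instead of computing a multi-dimensional index for every element. Below is a minimal standalone sketch of that idea in plain C++; the is_packed helper and the raw vectors are invented stand-ins for the real migraphx shape and argument types.

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <vector>

    // A shape is "packed" when its strides are the suffix products of its
    // lengths, i.e. the tensor's elements fill the buffer with no gaps.
    bool is_packed(const std::vector<std::size_t>& lens, const std::vector<std::size_t>& strides)
    {
        std::size_t expected = 1;
        for(std::size_t i = lens.size(); i > 0; i--)
        {
            if(strides[i - 1] != expected)
                return false;
            expected *= lens[i - 1];
        }
        return true;
    }

    int main()
    {
        std::vector<std::size_t> lens{2, 3, 4};
        std::vector<std::size_t> strides{12, 4, 1}; // suffix products -> packed
        assert(is_packed(lens, strides));

        std::vector<float> a(24, 1.0f), b(24, 2.0f), out(24);
        // Because both inputs are packed with equal shapes, one flat pass is
        // equivalent to visiting every (i, j, k) index of the 2x3x4 tensor.
        std::transform(a.begin(), a.end(), b.begin(), out.begin(), std::plus<>{});
        assert(out.front() == 3.0f && out.back() == 3.0f);
    }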
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct capture
{
    std::size_t ins_index;
    std::function<void(std::size_t ins_index, std::vector<argument>)> f{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.ins_index, "ins_index"));
    }

    std::string name() const { return "capture"; }
    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
    argument compute(const shape&, std::vector<argument> args) const
    {
        if(f)
        {
            f(ins_index, args);
        }
        else
        {
            MIGRAPHX_THROW("CAPTURE: callback function is not callable!");
        }
        return args.front();
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
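As a rough illustration of the compute contract above: the operator invokes the stored callback with its instruction index and argument list, then passes the first input through unchanged. The sketch below mocks this flow with invented stand-in types (fake_argument, fake_capture); it is not the real MIGraphX argument machinery.

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <vector>

    // Invented stand-in for migraphx::argument: just a buffer of floats.
    struct fake_argument
    {
        std::vector<float> data;
    };

    // Mirrors the shape of op::capture::compute: run the callback if one is
    // set, then forward the first input unchanged.
    struct fake_capture
    {
        std::size_t ins_index;
        std::function<void(std::size_t, std::vector<fake_argument>)> f;

        fake_argument compute(std::vector<fake_argument> args) const
        {
            if(f)
                f(ins_index, args);
            return args.front();
        }
    };

    int main()
    {
        fake_capture cap{3, [](std::size_t idx, const std::vector<fake_argument>& args) {
            std::cout << "instruction " << idx << " captured " << args.size() << " argument(s)\n";
        }};
        fake_argument out = cap.compute({fake_argument{{1.0f, 2.0f}}});
        std::cout << "passthrough size: " << out.data.size() << "\n";
    }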
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP

#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct quant_convolution
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};

    padding_mode_t padding_mode = default_;
    int group                   = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
    }

    std::string name() const { return "quant_convolution"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);

        const shape& input   = inputs.at(0);
        const shape& weights = inputs.at(1);
        auto t               = input.type();
        // all input type must be int8_type and output is float_type
        if(t != shape::int8_type)
        {
            MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
        }
        t = shape::int32_type;

        return {t,
                {
                    input.lens()[0],
                    weights.lens()[0],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
                        (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                         2 * padding[0]) /
                                stride[0] +
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
                        (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                         2 * padding[1]) /
                                stride[1] +
                            1)),
                }};
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
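For each spatial dimension, the extent computed above follows the standard convolution output formula out = max(1, (in - (1 + dilation * (k - 1)) + 2 * padding) / stride + 1). A worked example with hypothetical sizes: a 1x3x224x224 int8 input convolved with 64x3x7x7 weights at padding 3, stride 2, dilation 1 gives (224 - 7 + 6) / 2 + 1 = 112 per spatial dimension, so the int32 output shape is 1x64x112x112.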
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANT_DOT_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct quant_dot
{
    int32_t alpha = 1;
    int32_t beta  = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(as_number(self.alpha), "alpha"), f(as_number(self.beta), "beta"));
    }

    std::string name() const { return "quant_dot"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();

        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
        auto t         = a.type();
        if(t != shape::int8_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
        }

        if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
        {
            MIGRAPHX_THROW("QUANT_DOT: dot only accept 2 or more dims operands");
        }

        // only handle the case that the batch size of a and b are the same
        if(!std::equal(
               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
        {
            MIGRAPHX_THROW("QUANT_DOT: batch size of A and B mismatch: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        std::size_t dim_0 = a.lens().size() - 2;
        std::size_t dim_1 = a.lens().size() - 1;
        if(a.lens()[dim_1] != b.lens()[dim_0])
        {
            MIGRAPHX_THROW("QUANT_DOT: inner dimensions do not match: {" +
                           to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) + "}");
        }

        // k be multiple of 4
        if((a.lens()[dim_1] % 4) != 0)
        {
            MIGRAPHX_THROW("QUANT_DOT: size of A {" + to_string_range(a.lens()) + "} and B {" +
                           to_string_range(b.lens()) + "} must be multiple of 4 for int8 type");
        }

        auto out_lens   = a.lens();
        out_lens[dim_1] = b.lens()[dim_1];
        if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
        {
            MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
                           to_string_range(inputs.at(2).lens()) +
                           "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
        }
        if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
        {
            MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
        }
        return {shape::int32_type, out_lens};
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
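To make the guards above concrete with hypothetical shapes: an int8 A of {2, 4, 8} times an int8 B of {2, 8, 6} passes every check, since the batch dimensions {2} match, the inner dimensions agree (8 == 8), and k = 8 is a multiple of 4; the result is an int32 tensor of {2, 4, 6}. The same multiplication with an inner dimension of 6 would be rejected by the multiple-of-4 check even though the dimensions otherwise line up.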
@@ -27,26 +27,34 @@ struct unary : op_name<Derived>
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        result.visit([&](auto output) {
-            args[0].visit([&](auto input) {
-                if(input.get_shape().packed())
-                {
+        auto in_shape = args[0].get_shape();
+        if(in_shape.packed())
+        {
+            shape std_in_shape{in_shape.type(), in_shape.lens()};
+            shape std_out_shape{output_shape.type(), output_shape.lens()};
+            argument arg_in{std_in_shape, args[0].data()};
+            argument arg_out{std_out_shape, result.data()};
+            arg_out.visit([&](auto output) {
+                arg_in.visit([&](auto input) {
                     std::transform(input.begin(),
                                    input.end(),
                                    output.begin(),
                                    static_cast<const Derived&>(*this).apply());
-                    return result;
-                }
-                shape_for_each(output.get_shape(), [&](const auto& idx) {
-                    output(idx.begin(), idx.end()) =
-                        static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
                 });
-                return result;
             });
-        });
+        }
+        else
+        {
+            result.visit([&](auto output) {
+                args[0].visit([&](auto input) {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
+                            input(idx.begin(), idx.end()));
+                    });
+                });
+            });
+        }
         return result;
     }
...
@@ -13,6 +13,7 @@
 #include <migraphx/op/batch_norm.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
+#include <migraphx/op/capture.hpp>
 #include <migraphx/op/clip.hpp>
 #include <migraphx/op/common.hpp>
 #include <migraphx/op/concat.hpp>
@@ -45,6 +46,8 @@
 #include <migraphx/op/outline.hpp>
 #include <migraphx/op/pad.hpp>
 #include <migraphx/op/pooling.hpp>
+#include <migraphx/op/quant_convolution.hpp>
+#include <migraphx/op/quant_dot.hpp>
 #include <migraphx/op/pow.hpp>
 #include <migraphx/op/reduce_sum.hpp>
 #include <migraphx/op/reduce_mean.hpp>
...
@@ -15,6 +15,15 @@ struct program;

 void quantize(program& prog, const std::vector<std::string>& ins_names);
 void quantize(program& prog);

+// insert the capture operator for the inputs of each operator to be quantized
+// to int8
+std::size_t capture_arguments(program& prog,
+                              const std::vector<std::string>& ins_names,
+                              const std::function<void(std::size_t, std::vector<argument>)>& func);
+std::shared_ptr<std::vector<std::pair<float, float>>>
+capture_arguments(program& prog, const std::vector<std::string>& ins_names);
+std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -168,6 +168,7 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = nullptr)
 {
     double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
     auto error       = rms_range(r1, r2);
+    // cppcheck-suppress uninitvar
     if(out_error != nullptr)
         *out_error = error;
     return error <= threshold;
...
@@ -1011,9 +1011,10 @@ struct onnx_parser
     }

     std::vector<operation> vec_actv_funcs(vec_names.size());
-    std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
-        return map_actv_funcs[fn];
-    });
+    std::transform(vec_names.begin(),
+                   vec_names.end(),
+                   vec_actv_funcs.begin(),
+                   [&](const auto& fn) { return map_actv_funcs[fn]; });

     // To be added later
     float clip = 0.0;
@@ -1127,9 +1128,10 @@ struct onnx_parser
     }

     std::vector<operation> vec_actv_funcs(vec_names.size());
-    std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
-        return map_actv_funcs[name];
-    });
+    std::transform(vec_names.begin(),
+                   vec_names.end(),
+                   vec_actv_funcs.begin(),
+                   [&](const auto& name) { return map_actv_funcs[name]; });

     float clip = 0.0;
     if(contains(attributes, "clip"))
@@ -1299,9 +1301,10 @@ struct onnx_parser
     }

     std::vector<operation> vec_actv_funcs(vec_names.size());
-    std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
-        return map_actv_funcs[name];
-    });
+    std::transform(vec_names.begin(),
+                   vec_names.end(),
+                   vec_actv_funcs.begin(),
+                   [&](const auto& name) { return map_actv_funcs[name]; });

     float clip = 0.0;
     if(contains(attributes, "clip"))
...
@@ -85,6 +85,9 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
             offset += (element_size - (offset % element_size));
         conflict_queue.pop();
     }
+    // when int8 type is used, the offset could be any number
+    // if not 4-byte aligned, miopen int8 convolution can crash
+    offset         = (offset + 3) / 4 * 4;
     segment.offset = offset;
     MIGRAPHX_DEBUG(segment.dump());
     required_bytes = std::max(required_bytes, offset + segment.size);
...
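The added rounding is the usual integer align-up idiom: (offset + 3) / 4 * 4 is the smallest multiple of 4 that is not below offset, so already-aligned offsets pass through unchanged. A tiny self-contained check of the arithmetic (align4 is a name invented here for illustration):

    #include <cassert>
    #include <cstddef>

    // Round offset up to the next multiple of 4; aligned values stay put.
    std::size_t align4(std::size_t offset) { return (offset + 3) / 4 * 4; }

    int main()
    {
        assert(align4(13) == 16);
        assert(align4(16) == 16);
        assert(align4(0) == 0);
    }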
@@ -107,7 +107,7 @@ struct memory_coloring_impl
         return ins->name() == "check_context";
     }

-    static bool is_disjoin(live_range& range1, live_range& range2)
+    static bool is_disjoin(const live_range& range1, const live_range& range2)
     {
         if((range1.size == 0) || (range2.size == 0))
             return false;
...
@@ -241,7 +241,7 @@ instruction_ref program::remove_instructions(instruction_ref first, instruction_ref last)
     // TODO: Check every element
     assert(has_instruction(first));
     std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
-    assert(std::all_of(first, last, [&](instruction& ins) { return ins.outputs().empty(); }));
+    assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
     return impl->instructions.erase(first, last);
 }
...
@@ -156,6 +156,7 @@ PYBIND11_MODULE(migraphx, m)
     py::class_<migraphx::target>(m, "target");

     py::class_<migraphx::program>(m, "program")
+        .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
         .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
         .def("get_shape", &migraphx::program::get_shape)
         .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
...
@@ -3,32 +3,53 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/op/convert.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/add.hpp>
+#include <migraphx/op/quant_dot.hpp>
+#include <migraphx/op/capture.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/quant_convolution.hpp>
+#include <migraphx/op/multibroadcast.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <utility>
+#include <iomanip>
+#include <fstream>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-instruction_ref insert_fp16(program& prog,
-                            instruction_ref& ins,
-                            shape::type_t type,
-                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
+instruction_ref insert_quant_ins(program& prog,
+                                 instruction_ref& ins,
+                                 shape::type_t type,
+                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins)
 {
-    if(map_fp16.count(ins) > 0)
+    if(map_ins.count(ins) > 0)
     {
-        return map_fp16[ins];
+        return map_ins[ins];
     }
+
+    if(ins->name() == "undefined")
+    {
+        return ins;
+    }
+
     assert(ins->get_shape().type() == shape::float_type ||
-           ins->get_shape().type() == shape::double_type);
-    instruction_ref ins_fp16{};
-    ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
-    map_fp16[ins] = ins_fp16;
+           ins->get_shape().type() == shape::double_type ||
+           ins->get_shape().type() == shape::int32_type);
+    instruction_ref quant_ins{};
+    quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
+    map_ins[ins] = quant_ins;

-    return ins_fp16;
+    return quant_ins;
 }

+// This function is to convert any instructions specified in the input
+// from double or float to float16 by inserting a convert operator.
+// For the conversion, there could be cases of overflowing, but it
+// is very rare in the area of deeping learning, so we just do a
+// truncate of the input to get the fp16.
 void quantize(program& prog, const std::vector<std::string>& ins_names)
 {
     std::unordered_map<instruction_ref, instruction_ref> map_fp16;
@@ -59,7 +80,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
         }
         else
         {
-            input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
+            input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16);
         }
         converted_inputs.push_back(input_fp16);
     }
@@ -79,21 +100,13 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
         auto ins_shape = compute_shape(op, converted_inputs);
         if(ins_shape.type() != orig_type)
         {
-            // insert another convert instruction to convert it back
-            if(ins == std::prev(prog.end()))
-            {
-                prog.add_instruction(op::convert{orig_type}, ins);
-            }
-            else
-            {
-                // check the dead code case to avoid assert
-                bool output_empty = ins->outputs().empty();
-                auto ins_orig_type =
-                    prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
-                if(!output_empty)
-                {
-                    prog.replace_instruction(ins, ins_orig_type);
-                }
+            // check the dead code case to avoid assert
+            bool output_empty = ins->outputs().empty();
+            auto ins_orig_type =
+                prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
+            if(!output_empty)
+            {
+                prog.replace_instruction(ins, ins_orig_type);
             }
         }
@@ -103,5 +116,91 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
 void quantize(program& prog) { quantize(prog, {"all"}); }

+// For the input of each input argument, we need to insert a
+// capture operator to compute the scale and shift
+std::size_t capture_arguments(program& prog,
+                              const std::vector<std::string>& ins_names,
+                              const std::function<void(std::size_t, std::vector<argument>)>& func)
+{
+    size_t num_quant_params = 0;
+    // the int8 quantization only support dot and convolution
+    std::vector<std::string> op_names = {"dot", "convolution"};
+    if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
+           return std::find(op_names.begin(), op_names.end(), name) != op_names.end();
+       }))
+    {
+        MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
+    }
+
+    std::unordered_map<instruction_ref, instruction_ref> ins_map;
+    for(auto ins : iterator_for(prog))
+    {
+        if(not contains(ins_names, ins->name()))
+        {
+            continue;
+        }
+
+        auto inputs = ins->inputs();
+        std::vector<instruction_ref> new_args;
+        for(auto input : inputs)
+        {
+            instruction_ref new_ins{};
+            if(ins_map.count(input) > 0)
+            {
+                new_ins = ins_map[input];
+            }
+            else
+            {
+                new_ins = prog.insert_instruction(
+                    std::next(input), op::capture{num_quant_params++, func}, input);
+                ins_map[input] = new_ins;
+            }
+            new_args.push_back(new_ins);
+        }
+        instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
+    }
+
+    return num_quant_params;
+}
+
+std::shared_ptr<std::vector<std::pair<float, float>>>
+capture_arguments(program& prog, const std::vector<std::string>& ins_names)
+{
+    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
+        std::make_shared<std::vector<std::pair<float, float>>>();
+    std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
+
+    auto calc_quant_params = [int8_quant_params, max_abs_vals](
+                                 std::size_t ins_index, std::vector<migraphx::argument> args) {
+        std::pair<float, float> param_pair{64.0f, 0.0f};
+
+        // scale and shift is need for only int8 type, and we do not
+        // consider shift, so set shift to 0
+        std::vector<float> vec_val;
+        args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
+        auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
+        auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
+        auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
+        max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
+        param_pair.first            = 127.0f / max_abs_vals->at(ins_index);
+        int8_quant_params->at(ins_index) = param_pair;
+    };
+
+    auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
+    int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
+    max_abs_vals->resize(num_params, 0.0f);
+
+    return int8_quant_params;
+}
+
+std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog)
+{
+    std::vector<std::string> ins_names = {"dot", "convolution"};
+    return capture_arguments(prog, ins_names);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
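The calc_quant_params lambda above keeps a running maximum of |value| per captured instruction and derives the scale as 127 / max_abs, so the largest magnitude observed so far maps onto the int8 limit of 127. With hypothetical data: if an argument's values span [-3.2, 5.0], max_abs becomes 5.0 and the stored pair is (127 / 5.0 = 25.4, 0.0); the shift stays 0 because, as the comment notes, only symmetric scaling is considered. Parameters that are never observed keep the default pair (64.0, 0.0) set by the resize.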
@@ -674,7 +674,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
     std::vector<float> ihc_data(ihc_shape.elements(), 0.0);

     migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
-    std::vector<float> pph_data(pph_shape.elements(), 0.0);

     auto actv_funcs = lstm_actv_funcs(ins);
     auto lstm_op    = any_cast<op::lstm>(ins->get_operator());
...