Commit 6b26286c authored by Shucai Xiao

new branch for quantization.

parent 900bad8b
@@ -18,6 +18,7 @@ add_library(migraphx
     generate.cpp
     instruction.cpp
     program.cpp
+    quantization.cpp
     shape.cpp
     schedule.cpp
     simplify_algebra.cpp
...
@@ -27,6 +27,7 @@ struct instruction
     void replace(const shape& r);
     void recompute_shape();
+    void recompute_ins_shape();
     void clear_arguments();
@@ -67,7 +68,8 @@ struct instruction
     static void backreference(instruction_ref ref);
-    static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins);
+    static void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins,
+                                 bool recompute_shape = true);
     static void
     replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
...
@@ -1364,6 +1364,45 @@ struct lstm_last_cell_output
     }
 };

+struct fp_conversion
+{
+    bool reduce_precision = true;
+
+    std::string name() const { return "fp_conversion"; }
+
+    shape compute_shape(std::vector<shape> inputs) const
+    {
+        check_shapes{inputs, *this}.has(1);
+        if(reduce_precision)
+        {
+            if(inputs.front().type() != shape::float_type)
+            {
+                MIGRAPHX_THROW("FP_CONVERSION: input arguments must be type float");
+            }
+            return {shape::half_type, inputs.front().lens()};
+        }
+        else
+        {
+            if(inputs.front().type() != shape::half_type)
+            {
+                MIGRAPHX_THROW("FP_CONVERSION: input arguments must be type fp16");
+            }
+            return {shape::float_type, inputs.front().lens()};
+        }
+    }
+
+    argument compute(const shape& output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        result.visit([&](auto output) {
+            args.front().visit(
+                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
+        });
+        return result;
+    }
+};
+
 struct undefined
 {
     std::string name() const { return "undefined"; }
...
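The new fp_conversion operator copies its single input into an output of the other precision: float to half when reduce_precision is true (the default), half back to float otherwise. A rough sketch of wiring it into a program by hand against this branch follows; the parameter name and shape are purely illustrative:

#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <iostream>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {4}};

    // fp32 parameter -> fp16 -> back to fp32, exercising both directions of fp_conversion
    auto x    = p.add_parameter("x", s);                             // float_type
    auto half = p.add_instruction(migraphx::op::fp_conversion{}, x); // half_type
    p.add_instruction(migraphx::op::fp_conversion{false}, half);     // float_type again

    std::cout << p << std::endl;
    return 0;
}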
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_QUANTIZATION_HPP

#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program; // forward declaration so quantize can take a program reference

void quantize(program& prog);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -30,6 +30,13 @@ void instruction::replace(const shape& r)

 void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }

+void instruction::recompute_ins_shape()
+{
+    auto r = compute_shape(op, arguments);
+    if(r != result)
+        result = r;
+}
+
 void instruction::clear_arguments()
 {
     for(auto&& arg : arguments)
@@ -126,11 +133,13 @@ void instruction::backreference(instruction_ref ref)

 void instruction::replace_argument(instruction_ref ins,
                                    instruction_ref old,
-                                   instruction_ref new_ins)
+                                   instruction_ref new_ins,
+                                   bool recompute_shape)
 {
     ins->replace_argument(old, new_ins);
     backreference(ins);
-    ins->recompute_shape();
+    if(recompute_shape)
+        ins->recompute_shape();
 }

 void instruction::replace(instruction_ref ins,
...
@@ -176,7 +176,7 @@ void memory_coloring_impl::build()

 void memory_coloring_impl::rewrite()
 {
     std::vector<std::size_t> dims;
-    dims.push_back(required_bytes / sizeof(float));
+    dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
     shape s = {shape::float_type, dims};
     instruction_ref scratch_param = p_program->add_parameter("scratch", s);
     for(auto ins : iterator_for(*p_program))
...
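The single changed line above switches the scratch-buffer sizing from truncating to ceiling division, so a required_bytes value that is not a multiple of sizeof(float) (presumably now possible once 2-byte fp16 allocations share the scratch space) still gets enough float elements. A small standalone illustration of the arithmetic, with byte counts chosen only for the example:

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical byte counts; 10 is not a multiple of sizeof(float) == 4.
    for(std::size_t required_bytes : {8u, 10u})
    {
        std::size_t truncated = required_bytes / sizeof(float);                       // 10 -> 2 floats (8 bytes, too small)
        std::size_t rounded   = (required_bytes + sizeof(float) - 1) / sizeof(float); // 10 -> 3 floats (12 bytes, enough)
        std::printf("%zu bytes: %zu vs %zu floats\n", required_bytes, truncated, rounded);
    }
    return 0;
}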
@@ -2,6 +2,7 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <migraphx/program.hpp>
+#include <migraphx/quantization.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/stringutils.hpp>
@@ -181,6 +182,7 @@ PYBIND11_MODULE(migraphx, m)
     });
     m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
+    m.def("quantize", &migraphx::quantize);
 #ifdef HAVE_GPU
     m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
...
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
#include <cassert>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Create an fp16 counterpart of a float literal or parameter and return it.
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
{
    assert(ins->get_shape().type() == shape::float_type);
    assert(ins->name().front() == '@');

    instruction_ref ins_fp16{};
    if(ins->name() == "@literal")
    {
        // Re-create the literal with half_type; the literal constructor converts the values.
        std::vector<float> values;
        auto l_fp32 = ins->get_literal();
        shape s     = ins->get_shape();
        l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); });
        ins_fp16 = prog.add_literal(literal({shape::half_type, s.lens()}, values));
    }
    else if(ins->name() == "@param")
    {
        // Insert an fp_conversion right after the parameter (or append it when the
        // parameter is the last instruction in the program).
        if(ins == std::prev(prog.end()))
        {
            ins_fp16 = prog.add_instruction(op::fp_conversion{}, ins);
        }
        else
        {
            ins_fp16 = prog.insert_instruction(std::next(ins), op::fp_conversion{}, ins);
        }
    }

    return ins_fp16;
}

void quantize(program& prog)
{
    bool reduced_precision = false;
    for(auto ins : iterator_for(prog))
    {
        // Convert float_type literals and parameters to half_type.
        if(ins->name().front() == '@' && ins->get_shape().type() == shape::float_type)
        {
            auto ins_fp16 = convert_fp32_fp16(prog, ins);
            auto outputs  = ins->outputs();
            for(auto output : outputs)
            {
                if(output != ins_fp16)
                {
                    instruction::replace_argument(output, ins, ins_fp16, false);
                }
            }
            reduced_precision = true;
        }
    }

    // Propagate the new half_type shapes and, if needed, append a final
    // instruction to convert the fp16 result back to fp32.
    if(reduced_precision)
    {
        for(auto ins : iterator_for(prog))
        {
            if(ins->name().front() != '@')
            {
                ins->recompute_ins_shape();
            }
        }

        auto ins = std::prev(prog.end());
        if(ins->get_shape().type() == shape::half_type)
        {
            prog.add_instruction(op::fp_conversion{false}, ins);
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
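For context, a rough sketch of driving the new pass end to end; the toy add program, the CPU target, and the generated inputs below are assumptions for illustration, not part of this commit:

#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <iostream>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};

    // Build a small fp32 program: x + y
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(migraphx::op::add{}, x, y);

    // Downcast literals/parameters to fp16 and append a final fp16 -> fp32 conversion
    migraphx::quantize(p);

    p.compile(migraphx::cpu::target{});

    migraphx::program::parameter_map params;
    params["x"] = migraphx::generate_argument(s);
    params["y"] = migraphx::generate_argument(s);
    std::cout << p.eval(params) << std::endl;
    return 0;
}

With the pass as written above, the parameters keep their float shapes, the inserted fp_conversion instructions feed fp16 values into the add, and the trailing fp_conversion{false} returns the final result to fp32, so fp32 data can still be supplied and read back unchanged.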
@@ -682,6 +682,18 @@ struct cpu_logsoftmax
     }
 };

+struct cpu_fp_conversion
+{
+    op::fp_conversion op;
+    std::string name() const { return "cpu_fp_conversion"; }
+    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
+    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    {
+        return op.compute(output_shape, std::move(args));
+    }
+};
+
 struct add_op
 {
     std::string name() const { return "add"; }
@@ -792,6 +804,7 @@ struct cpu_apply
         apply_map["pad"] = extend_op<cpu_pad, op::pad>();
         apply_map["concat"] = extend_op<cpu_concat, op::concat>();
         apply_map["gather"] = extend_op<cpu_gather, op::gather>();
+        apply_map["fp_conversion"] = extend_op<cpu_fp_conversion, op::fp_conversion>();
        apply_map["logsoftmax"] = extend_op<cpu_logsoftmax, op::logsoftmax>();
         apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
         apply_map["elu"] = extend_op<cpu_unary<elu_op>, op::elu>();
...