Commit d95143d5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Switch to calling replace_argument and remove unnecessary code

parent 723d1f0b
...@@ -18,7 +18,6 @@ add_library(migraphx ...@@ -18,7 +18,6 @@ add_library(migraphx
generate.cpp generate.cpp
instruction.cpp instruction.cpp
program.cpp program.cpp
quantization.cpp
quantize_ins.cpp quantize_ins.cpp
shape.cpp shape.cpp
schedule.cpp schedule.cpp
......
...@@ -27,7 +27,6 @@ struct instruction ...@@ -27,7 +27,6 @@ struct instruction
void replace(const shape& r); void replace(const shape& r);
void recompute_shape(); void recompute_shape();
void recompute_ins_shape();
void clear_arguments(); void clear_arguments();
...@@ -70,8 +69,7 @@ struct instruction ...@@ -70,8 +69,7 @@ struct instruction
static void replace_argument(instruction_ref ins, static void replace_argument(instruction_ref ins,
instruction_ref old, instruction_ref old,
instruction_ref new_ins, instruction_ref new_ins);
bool recompute_shape = true);
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_QUANTIZATION_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Rewrites @p prog in place so that its float/double "@literal" and "@param"
/// instructions feed their consumers through half-precision replacements.
/// When any such replacement happened and the final instruction ends up with
/// a half_type shape, one more fp_conversion instruction is appended to
/// convert the result back to the original type.
///
/// @param prog program to modify; instructions are added to it and the
///             arguments of existing instructions are rewired.
void quantize(program& prog);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -30,13 +30,6 @@ void instruction::replace(const shape& r) ...@@ -30,13 +30,6 @@ void instruction::replace(const shape& r)
void instruction::recompute_shape() { replace(compute_shape(op, arguments)); } void instruction::recompute_shape() { replace(compute_shape(op, arguments)); }
/// Recomputes this instruction's output shape from its operator and current
/// arguments and stores it in `result`. The assignment is made directly on
/// the member rather than through replace(), and only when the freshly
/// computed shape differs from the stored one.
void instruction::recompute_ins_shape()
{
    auto fresh_shape = compute_shape(op, arguments);
    if(fresh_shape != result)
    {
        result = fresh_shape;
    }
}
void instruction::clear_arguments() void instruction::clear_arguments()
{ {
for(auto&& arg : arguments) for(auto&& arg : arguments)
...@@ -133,12 +126,10 @@ void instruction::backreference(instruction_ref ref) ...@@ -133,12 +126,10 @@ void instruction::backreference(instruction_ref ref)
void instruction::replace_argument(instruction_ref ins, void instruction::replace_argument(instruction_ref ins,
instruction_ref old, instruction_ref old,
instruction_ref new_ins, instruction_ref new_ins)
bool recompute_shape)
{ {
ins->replace_argument(old, new_ins); ins->replace_argument(old, new_ins);
backreference(ins); backreference(ins);
if(recompute_shape)
ins->recompute_shape(); ins->recompute_shape();
} }
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_ins.hpp> #include <migraphx/quantize_ins.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
...@@ -183,7 +182,6 @@ PYBIND11_MODULE(migraphx, m) ...@@ -183,7 +182,6 @@ PYBIND11_MODULE(migraphx, m)
}); });
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize", &migraphx::quantize);
m.def("quantize_ins", &migraphx::quantize_ins); m.def("quantize_ins", &migraphx::quantize_ins);
#ifdef HAVE_GPU #ifdef HAVE_GPU
......
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/fp_conversion.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Creates a half-precision stand-in for a float/double "@literal" or
/// "@param" instruction.
///
/// - "@literal": the stored values are read as float or double (according to
///   the instruction's shape type) and a new literal with type half_type and
///   the same lens is added to the program.
/// - "@param": an op::fp_conversion{} instruction taking @p ins as its input
///   is appended when @p ins is the last instruction, otherwise inserted
///   immediately after @p ins.
///   NOTE(review): presumably a default-constructed fp_conversion targets
///   half_type -- confirm against op/fp_conversion.hpp.
///
/// @param prog program that receives the new instruction.
/// @param ins  instruction to convert; its type and name are only checked by
///             assert, so the preconditions are not enforced when asserts are
///             compiled out.
/// @return reference to the new instruction, or a default-constructed
///         instruction_ref when @p ins is neither "@literal" nor "@param".
instruction_ref convert_to_fp16(program& prog, instruction_ref& ins)
{
    assert(ins->get_shape().type() == shape::float_type ||
           ins->get_shape().type() == shape::double_type);
    assert(contains({"@literal", "@param"}, ins->name()));

    if(ins->name() == "@literal")
    {
        shape s  = ins->get_shape();
        auto lit = ins->get_literal();

        if(s.type() == shape::float_type)
        {
            auto values = lit.get<const float>();
            return prog.add_literal(
                literal({shape::half_type, s.lens()}, values.begin(), values.end()));
        }

        // The asserted precondition leaves double_type as the only other case.
        auto values = lit.get<const double>();
        return prog.add_literal(
            literal({shape::half_type, s.lens()}, values.begin(), values.end()));
    }

    if(ins->name() == "@param")
    {
        // Append when the parameter is the final instruction; otherwise place
        // the conversion directly behind it.
        const bool is_last = (ins == std::prev(prog.end()));
        return is_last ? prog.add_instruction(op::fp_conversion{}, ins)
                       : prog.insert_instruction(std::next(ins), op::fp_conversion{}, ins);
    }

    return instruction_ref{};
}
/// Converts the float/double "@literal" and "@param" instructions of @p prog
/// to half precision and rewires their consumers to the converted versions.
///
/// If at least one instruction was converted, the shapes of all remaining
/// (non-literal, non-param) instructions are recomputed in program order, and
/// when the last instruction then has type half_type an fp_conversion back to
/// the original type is appended. The "original type" is the type of the most
/// recently converted instruction.
///
/// @param prog program to modify in place.
void quantize(program& prog)
{
    const auto is_literal_or_param = [](instruction_ref i) {
        return contains({"@literal", "@param"}, i->name());
    };

    bool reduced_precision = false;
    shape::type_t orig_type = shape::float_type;

    for(auto ins : iterator_for(prog))
    {
        if(!is_literal_or_param(ins))
            continue;

        const auto type = ins->get_shape().type();
        if(type != shape::float_type && type != shape::double_type)
            continue;

        // convert float_type / double_type to half_type
        orig_type     = type;
        auto ins_fp16 = convert_to_fp16(prog, ins);

        // Work on a copy of the output list: replace_argument is called on
        // the consumers while we walk it. For a "@param" the new conversion
        // instruction is itself a consumer of `ins` and must keep that input.
        auto consumers = ins->outputs();
        for(auto consumer : consumers)
        {
            if(consumer == ins_fp16)
                continue;
            instruction::replace_argument(consumer, ins, ins_fp16, false);
        }
        reduced_precision = true;
    }

    if(!reduced_precision)
        return;

    // Shapes were deliberately not recomputed during rewiring (last argument
    // `false` above); refresh them now for everything that is computed.
    for(auto ins : iterator_for(prog))
    {
        if(is_literal_or_param(ins))
            continue;
        ins->recompute_ins_shape();
    }

    // add another instruction at last to convert fp16 back to the original type
    auto last = std::prev(prog.end());
    if(last->get_shape().type() == shape::half_type)
    {
        prog.add_instruction(op::fp_conversion{orig_type}, last);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -61,6 +61,7 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -61,6 +61,7 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
// process all inputs, if input is a fp32 or fp64, convert it // process all inputs, if input is a fp32 or fp64, convert it
// to a fp16 by adding a fp_conversion operator. // to a fp16 by adding a fp_conversion operator.
auto inputs = ins->inputs(); auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs) for(auto input : inputs)
{ {
auto s = input->get_shape(); auto s = input->get_shape();
...@@ -77,16 +78,24 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names) ...@@ -77,16 +78,24 @@ void quantize_ins(program& prog, const std::vector<std::string>& ins_names)
{ {
input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16); input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
} }
instruction::replace_argument(ins, input, input_fp16, false); //instruction::replace_argument(ins, input, input_fp16, false);
converted_inputs.push_back(input_fp16);
} }
else
{
converted_inputs.push_back(input);
}
}
if (inputs != converted_inputs)
{
auto op = ins->get_operator();
instruction::replace(ins, op, compute_shape(op, converted_inputs), converted_inputs);
} }
// recompute the output shape
ins->recompute_ins_shape();
// If output is not the original type, add another instruction if (ins->get_shape().type() != orig_type)
// to convert it back to the original type
if(ins->get_shape().type() != orig_type)
{ {
// insert another fp_conversion instruction to convert it back
if(ins == std::prev(prog.end())) if(ins == std::prev(prog.end()))
{ {
prog.add_instruction(op::fp_conversion{orig_type}, ins); prog.add_instruction(op::fp_conversion{orig_type}, ins);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment