Commit 852a517a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 34493a8d
...@@ -11,10 +11,11 @@ namespace migraphx { ...@@ -11,10 +11,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_quant_ins(program& prog, instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type, shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins,
float scale = 1.0f, float shift = 0.0f) float scale = 1.0f,
float shift = 0.0f)
{ {
if(map_ins.count(ins) > 0) if(map_ins.count(ins) > 0)
{ {
...@@ -25,7 +26,7 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -25,7 +26,7 @@ instruction_ref insert_quant_ins(program& prog,
ins->get_shape().type() == shape::double_type || ins->get_shape().type() == shape::double_type ||
ins->get_shape().type() == shape::int32_type); ins->get_shape().type() == shape::int32_type);
instruction_ref quant_ins{}; instruction_ref quant_ins{};
quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins); quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
return quant_ins; return quant_ins;
...@@ -34,7 +35,7 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -34,7 +35,7 @@ instruction_ref insert_quant_ins(program& prog,
// This function is to convert any instructions specified in the input // This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator. // from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it // For the conversion, there could be cases of overflowing, but it
// is very rare in the area of deep learning, so we just do a // is very rare in the area of deep learning, so we just do a
// truncate of the input to get the fp16. // truncate of the input to get the fp16.
void quantize(program& prog, const std::vector<std::string>& ins_names) void quantize(program& prog, const std::vector<std::string>& ins_names)
{ {
...@@ -103,16 +104,16 @@ void quantize(program& prog, const std::vector<std::string>& ins_names) ...@@ -103,16 +104,16 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
void quantize(program& prog) { quantize(prog, {"all"}); } void quantize(program& prog) { quantize(prog, {"all"}); }
// int8 quantization is different from fp16 since int8 can only handle values // int8 quantization is different from fp16 since int8 can only handle values
// in -128 ~ 127. To convert the float or double to int8, we need a scale and // in -128 ~ 127. To convert the float or double to int8, we need a scale and
// a shift, then the conversion can be done as v_int8 = fp * scale + shift. // a shift, then the conversion can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now. // To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8(program& prog, const std::vector<std::string>& ins_names) void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
{ {
// For now, we only support the int8 quantization of gemm and convolution // For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"}; std::vector<std::string> op_names = {"dot", "convolution"};
if (!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) { if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
return std::find(op_names.begin(), op_names.end(), name); return std::find(op_names.begin(), op_names.end(), name);
})) }))
{ {
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
} }
...@@ -135,24 +136,25 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -135,24 +136,25 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
std::vector<instruction_ref> converted_inputs; std::vector<instruction_ref> converted_inputs;
// process all inputs, if input is a fp32 or fp64, convert it // process all inputs, if input is a fp32 or fp64, convert it
// to an int8 type by adding a convert operator and replace // to an int8 type by adding a convert operator and replace
// the operator with the corresponding int8 version // the operator with the corresponding int8 version
auto inputs = ins->inputs(); auto inputs = ins->inputs();
std::size_t param_index = 0; std::size_t param_index = 0;
for(auto input : inputs) for(auto input : inputs)
{ {
// In general, the target_type is int8, but for the dot // In general, the target_type is int8, but for the dot
// operation, if it has 3 inputs, then the last one should // operation, if it has 3 inputs, then the last one should
// be converted to int32_type // be converted to int32_type
shape::type_t quant_type = shape::int8_type; shape::type_t quant_type = shape::int8_type;
if (ins->name() == "dot" and inputs.size() == 3 and input == inputs.back()) if(ins->name() == "dot" and inputs.size() == 3 and input == inputs.back())
{ {
quant_type = shape::int32_type; quant_type = shape::int32_type;
} }
auto param = int8_param[param_index++]; auto param = int8_param[param_index++];
auto s = input->get_shape(); auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type || s.type() == shape::int32_type) if(s.type() == shape::float_type || s.type() == shape::double_type ||
s.type() == shape::int32_type)
{ {
// if the input is a convert operator, use its input // if the input is a convert operator, use its input
// as its current input // as its current input
...@@ -160,19 +162,20 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -160,19 +162,20 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
if(input->name() == "convert") if(input->name() == "convert")
{ {
auto tmp_ins = input->inputs().front(); auto tmp_ins = input->inputs().front();
if (tmp_ins->get_shape().type() == quant_type) if(tmp_ins->get_shape().type() == quant_type)
{ {
quant_input = input->inputs().front(); quant_input = input->inputs().front();
} }
else else
{ {
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second); quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
} }
} }
else else
{ {
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second); quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
} }
converted_inputs.push_back(quant_input); converted_inputs.push_back(quant_input);
} }
...@@ -207,12 +210,10 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -207,12 +210,10 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
// used as scale and shift(.0f), which will generate results different from // used as scale and shift(.0f), which will generate results different from
// the original results. To adjust the output to be "correct(approximately // the original results. To adjust the output to be "correct(approximately
// equal)", we need additional calculation for that. // equal)", we need additional calculation for that.
prog.replace_instruction(ins, op, converted_inputs); prog.replace_instruction(ins, op, converted_inputs);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -13,9 +13,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const ...@@ -13,9 +13,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
argument hip_convert::compute(context& ctx, argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
const shape&,
const std::vector<argument>& args) const
{ {
device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift); device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift);
return args[1]; return args[1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment