Commit d8f5cf8e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parents 12eb2c83 6560d17b
...@@ -16,19 +16,28 @@ ...@@ -16,19 +16,28 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins) instruction_ref convert_to_fp16(program& prog, instruction_ref& ins)
{ {
assert(ins->get_shape().type() == shape::float_type || assert(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type); ins->get_shape().type() == shape::double_type);
assert(ins->name().front() == '@'); assert(contains({"@literal", "@param"}, ins->name()));
instruction_ref ins_fp16{}; instruction_ref ins_fp16{};
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
std::vector<float> values; shape s = ins->get_shape();
auto l_fp32 = ins->get_literal(); auto l = ins->get_literal();
shape s = ins->get_shape(); if(s.type() == shape::float_type)
l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); }); {
ins_fp16 = prog.add_literal(literal({shape::half_type, s.lens()}, values)); auto tv = l.get<const float>();
ins_fp16 =
prog.add_literal(literal({shape::half_type, s.lens()}, tv.begin(), tv.end()));
}
else
{
auto tv = l.get<const double>();
ins_fp16 =
prog.add_literal(literal({shape::half_type, s.lens()}, tv.begin(), tv.end()));
}
} }
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
...@@ -47,13 +56,17 @@ instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins) ...@@ -47,13 +56,17 @@ instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
void quantize(program& prog) void quantize(program& prog)
{ {
bool reduced_precision = false; bool reduced_precision = false;
shape::type_t orig_type = shape::float_type;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
// convert float_type to half_type // convert float_type to half_type
if(ins->name().front() == '@' && ins->get_shape().type() == shape::float_type) if(contains({"@literal", "@param"}, ins->name()) &&
(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type))
{ {
auto ins_fp16 = convert_fp32_fp16(prog, ins); orig_type = ins->get_shape().type();
auto ins_fp16 = convert_to_fp16(prog, ins);
auto outputs = ins->outputs(); auto outputs = ins->outputs();
for(auto output : outputs) for(auto output : outputs)
{ {
...@@ -72,7 +85,7 @@ void quantize(program& prog) ...@@ -72,7 +85,7 @@ void quantize(program& prog)
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
{ {
if(ins->name().front() != '@') if(!contains({"@literal", "@param"}, ins->name()))
{ {
ins->recompute_ins_shape(); ins->recompute_ins_shape();
} }
...@@ -81,7 +94,7 @@ void quantize(program& prog) ...@@ -81,7 +94,7 @@ void quantize(program& prog)
auto ins = std::prev(prog.end()); auto ins = std::prev(prog.end());
if(ins->get_shape().type() == shape::half_type) if(ins->get_shape().type() == shape::half_type)
{ {
prog.add_instruction(op::fp_conversion{shape::float_type}, ins); prog.add_instruction(op::fp_conversion{orig_type}, ins);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment