Commit d8f5cf8e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parents 12eb2c83 6560d17b
......@@ -16,19 +16,28 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
instruction_ref convert_to_fp16(program& prog, instruction_ref& ins)
{
assert(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type);
assert(ins->name().front() == '@');
assert(contains({"@literal", "@param"}, ins->name()));
instruction_ref ins_fp16{};
if(ins->name() == "@literal")
{
std::vector<float> values;
auto l_fp32 = ins->get_literal();
shape s = ins->get_shape();
l_fp32.visit([&](auto val) { values.assign(val.begin(), val.end()); });
ins_fp16 = prog.add_literal(literal({shape::half_type, s.lens()}, values));
shape s = ins->get_shape();
auto l = ins->get_literal();
if(s.type() == shape::float_type)
{
auto tv = l.get<const float>();
ins_fp16 =
prog.add_literal(literal({shape::half_type, s.lens()}, tv.begin(), tv.end()));
}
else
{
auto tv = l.get<const double>();
ins_fp16 =
prog.add_literal(literal({shape::half_type, s.lens()}, tv.begin(), tv.end()));
}
}
else if(ins->name() == "@param")
{
......@@ -47,13 +56,17 @@ instruction_ref convert_fp32_fp16(program& prog, instruction_ref& ins)
void quantize(program& prog)
{
bool reduced_precision = false;
bool reduced_precision = false;
shape::type_t orig_type = shape::float_type;
for(auto ins : iterator_for(prog))
{
// convert float_type to half_type
if(ins->name().front() == '@' && ins->get_shape().type() == shape::float_type)
if(contains({"@literal", "@param"}, ins->name()) &&
(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type))
{
auto ins_fp16 = convert_fp32_fp16(prog, ins);
orig_type = ins->get_shape().type();
auto ins_fp16 = convert_to_fp16(prog, ins);
auto outputs = ins->outputs();
for(auto output : outputs)
{
......@@ -72,7 +85,7 @@ void quantize(program& prog)
{
for(auto ins : iterator_for(prog))
{
if(ins->name().front() != '@')
if(!contains({"@literal", "@param"}, ins->name()))
{
ins->recompute_ins_shape();
}
......@@ -81,7 +94,7 @@ void quantize(program& prog)
auto ins = std::prev(prog.end());
if(ins->get_shape().type() == shape::half_type)
{
prog.add_instruction(op::fp_conversion{shape::float_type}, ins);
prog.add_instruction(op::fp_conversion{orig_type}, ins);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment