Commit 9f6153b2 authored by Shucai Xiao

fix code review comments

parent 5645b165
@@ -82,8 +82,7 @@ std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed
 {
     std::vector<T> result(s.elements());
     std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
-    // divide a value to avoid integer overflow
-    std::transform(result.begin(), result.end(), result.begin(), [](auto i) { return i / 32; });
+    std::transform(result.begin(), result.end(), result.begin(), [](auto i) { return i; });
     // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
     // std::generate(result.begin(), result.end(), []{ return 1; });
     return result;
......
@@ -20,14 +20,11 @@ namespace op {
 struct convert : unary<convert>
 {
     shape::type_t target_type = shape::half_type;
-    float scale = 1.0f;
-    float shift = 0.0f;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(
-            f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
+        return pack(f(self.target_type, "target_type"));
     }
     shape compute_shape(std::vector<shape> inputs) const
@@ -38,22 +35,10 @@ struct convert : unary<convert>
     auto apply() const
     {
-        return [&](auto x) {
-            float res = scale * x + shift;
-            if(target_type == shape::int8_type)
-            {
-                int factor = (res >= 0.0f) ? 1 : -1;
-                res = res + factor * 0.5f;
-                res = res > 127.0f ? 127.0f : res;
-                res = res < -128.0f ? -128.0f : res;
-            }
-            return res;
-        };
+        return [](auto x) { return x; };
     }
     convert(shape::type_t t) : target_type{t} {}
-    convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
     convert() {}
 };
......
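For reference, the apply() body deleted above combined a scale/shift with round-half-away-from-zero and int8 saturation, while the new operator is a plain pass-through that leaves the cast to the generic unary machinery. Below is a minimal host-side sketch of the two behaviours; the free functions are illustrative only and not part of MIGraphX.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Behaviour of the removed apply() path, extracted as a free function:
// scale/shift the input, round half away from zero, and clamp to the int8
// range before the final cast.
int8_t old_convert_int8(float x, float scale, float shift)
{
    float res  = scale * x + shift;
    int factor = (res >= 0.0f) ? 1 : -1; // round half away from zero
    res        = res + factor * 0.5f;
    res        = std::min(127.0f, std::max(-128.0f, res));
    return static_cast<int8_t>(res);
}

// The replacement apply() is an identity lambda; the surrounding convert
// machinery performs the plain type cast afterwards.
int8_t new_convert_int8(float x) { return static_cast<int8_t>(x); }

int main()
{
    std::cout << int(old_convert_int8(7.9f, 16.0f, 1.0f)) << '\n'; // 127 (rounded and saturated)
    std::cout << int(new_convert_int8(7.9f)) << '\n';              // 7 (truncating cast)
}
```

With scale 16 and shift 1, an input of 7.9 saturates to 127 under the old logic but truncates to 7 under a plain cast, which shows the behavioural difference the removal introduces.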
@@ -23,9 +23,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 instruction_ref insert_quant_ins(program& prog,
                                  instruction_ref& ins,
                                  shape::type_t type,
-                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins,
-                                 float scale = 1.0f,
-                                 float shift = 0.0f)
+                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins)
 {
     if(map_ins.count(ins) > 0)
     {
@@ -37,16 +35,11 @@ instruction_ref insert_quant_ins(program& prog,
         return ins;
     }
-    if(scale < 0.0f)
-    {
-        MIGRAPHX_THROW("INSERT_QUANT_INS: scale less than 0");
-    }
     assert(ins->get_shape().type() == shape::float_type ||
            ins->get_shape().type() == shape::double_type ||
            ins->get_shape().type() == shape::int32_type);
     instruction_ref quant_ins{};
-    quant_ins = prog.insert_instruction(std::next(ins), op::convert{type, scale, shift}, ins);
+    quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
     map_ins[ins] = quant_ins;
     return quant_ins;
......
@@ -15,7 +15,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const
 argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift, op.target_type);
+    device::convert(ctx.get_stream().get(), args[1], args[0]);
     return args[1];
 }
......
@@ -6,31 +6,14 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-void convert(hipStream_t stream,
-             const argument& result,
-             const argument& arg,
-             float scale,
-             float shift,
-             shape::type_t target_type)
+void convert(hipStream_t stream, const argument& result, const argument& arg)
 {
     result.visit([&](auto output) {
         arg.visit([&](auto input) {
             const auto* input_ptr = device_cast(input.data());
             auto* output_ptr = device_cast(output.data());
-            if(target_type == shape::int8_type)
-            {
-                gs_launch(stream, result.get_shape().elements())([=](auto i) {
-                    float res = input_ptr[i] * scale + shift;
-                    int factor = (res >= 0.0f) ? 1 : -1;
-                    output_ptr[i] = static_cast<int8_t>(
-                        std::min<float>(std::max<float>(-128.0f, res + factor * 0.5), 127.0f));
-                });
-            }
-            else
-            {
-                gs_launch(stream, result.get_shape().elements())(
-                    [=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; });
-            }
+            gs_launch(stream,
+                      result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; });
         });
     });
 }
......
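The simplified kernel above relies on the per-element assignment output_ptr[i] = input_ptr[i] to perform the conversion, since result and arg are visited with their concrete element types. A small host-side analogue of that idea is sketched below, assuming in-range values (an out-of-range float-to-int8 cast would be undefined behaviour); the function name and types are illustrative, not MIGraphX API.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Host-side analogue of the simplified device::convert: convert element-wise
// by letting the assignment/cast between the two element types do the work.
template <class Out, class In>
void convert_like(std::vector<Out>& output, const std::vector<In>& input)
{
    std::transform(input.begin(), input.end(), output.begin(),
                   [](In x) { return static_cast<Out>(x); });
}

int main()
{
    std::vector<float> in{1.25f, -2.75f, 126.9f};
    std::vector<int8_t> out(in.size());
    convert_like(out, in);
    for(auto v : out)
        std::cout << int(v) << ' '; // 1 -2 126 : plain cast, no rounding or clamping
    std::cout << '\n';
}
```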
@@ -11,12 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-void convert(hipStream_t stream,
-             const argument& result,
-             const argument& arg,
-             float scale,
-             float shift,
-             shape::type_t target_type);
+void convert(hipStream_t stream, const argument& result, const argument& arg);
 } // namespace device
 } // namespace gpu
......
@@ -3797,9 +3797,9 @@ struct test_convert : verify_program<test_convert>
     auto pa = p.add_parameter("a", sa);
     auto pb = p.add_parameter("b", sb);
     auto ia =
-        p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type, 16.0f, 1.0f}, pa);
+        p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
     auto ib =
-        p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type, 16.0f, 2.0f}, pb);
+        p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
     p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
     return p;
......