Commit cbfbd04c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

extend the convert operator for int8 quantization

parent b643f202
...@@ -38,7 +38,16 @@ struct convert : unary<convert> ...@@ -38,7 +38,16 @@ struct convert : unary<convert>
auto apply() const auto apply() const
{ {
return [&](auto x) { return scale * x + shift; }; return [&](auto x) {
float res = scale * x + shift;
if (target_type == shape::int8_type)
{
res = res > 127.0 ? 127.0 : res;
res = res < -128.0 ? -128.0 : res;
}
return res;
};
} }
convert(shape::type_t t) : target_type{t} {} convert(shape::type_t t) : target_type{t} {}
......
...@@ -15,7 +15,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const ...@@ -15,7 +15,7 @@ shape hip_convert::compute_shape(std::vector<shape> inputs) const
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift); device::convert(ctx.get_stream().get(), args[1], args[0], op.scale, op.shift, op.target_type);
return args[1]; return args[1];
} }
......
...@@ -6,15 +6,30 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,15 +6,30 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void convert( void convert(hipStream_t stream,
hipStream_t stream, const argument& result, const argument& arg, float scale, float shift) const argument& result,
const argument& arg,
float scale,
float shift,
shape::type_t target_type)
{ {
result.visit([&](auto output) { result.visit([&](auto output) {
arg.visit([&](auto input) { arg.visit([&](auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
if(target_type == shape::int8_type)
{
gs_launch(stream, result.get_shape().elements())(
[=](auto i) {
output_ptr[i] = std::min<int8_t>(std::max<float>(-128, input_ptr[i]
* scale + shift), 127);
});
}
else
{
gs_launch(stream, result.get_shape().elements())( gs_launch(stream, result.get_shape().elements())(
[=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; }); [=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; });
}
}); });
}); });
} }
......
...@@ -12,7 +12,7 @@ namespace gpu { ...@@ -12,7 +12,7 @@ namespace gpu {
namespace device { namespace device {
void convert( void convert(
hipStream_t stream, const argument& result, const argument& arg, float scale, float shift); hipStream_t stream, const argument& result, const argument& arg, float scale, float shift, shape::type_t target_type);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment