Commit faef98bf authored by Shucai Xiao's avatar Shucai Xiao
Browse files

reduce the rounding error in converting to int8

parent 41344324
...@@ -42,6 +42,7 @@ struct convert : unary<convert> ...@@ -42,6 +42,7 @@ struct convert : unary<convert>
float res = scale * x + shift; float res = scale * x + shift;
if(target_type == shape::int8_type) if(target_type == shape::int8_type)
{ {
res = res + 0.5f;
res = res > 127.0 ? 127.0 : res; res = res > 127.0 ? 127.0 : res;
res = res < -128.0 ? -128.0 : res; res = res < -128.0 ? -128.0 : res;
} }
......
#include <migraphx/gpu/device/convert.hpp> #ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <migraphx/gpu/device/nary.hpp> #define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace op {
namespace device {
struct convert : unary<convert>
void convert(hipStream_t stream,
const argument& result,
const argument& arg,
float scale,
float shift,
shape::type_t target_type)
{ {
result.visit([&](auto output) { shape::type_t target_type = shape::half_type;
arg.visit([&](auto input) { float scale = 1.0f;
const auto* input_ptr = device_cast(input.data()); float shift = 0.0f;
auto* output_ptr = device_cast(output.data());
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
auto apply() const
{
return [&](auto x) {
float res = scale * x + shift;
if(target_type == shape::int8_type) if(target_type == shape::int8_type)
{ {
gs_launch(stream, result.get_shape().elements())([=](auto i) { res = res + 0.5f;
output_ptr[i] = res = res > 127.0 ? 127.0 : res;
std::min<int8_t>(std::max<float>(-128, input_ptr[i] * scale + shift), 127); res = res < -128.0 ? -128.0 : res;
});
} }
else
{
gs_launch(stream, result.get_shape().elements())(
[=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; });
}
});
});
}
} // namespace device return res;
} // namespace gpu };
}
convert(shape::type_t t) : target_type{t} {}
convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
convert() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment