"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "2a86fa207101c1c2f727fb9e04437b6b075e0788"
Commit faef98bf authored by Shucai Xiao
Browse files

Reduce the rounding error when converting to int8

parent 41344324
......@@ -42,6 +42,7 @@ struct convert : unary<convert>
float res = scale * x + shift;
if(target_type == shape::int8_type)
{
res = res + 0.5f;
res = res > 127.0 ? 127.0 : res;
res = res < -128.0 ? -128.0 : res;
}
......
// NOTE(review): the lines below appear to interleave two different sources:
//   (a) a GPU device implementation: a free function `convert` inside
//       migraphx::gpu::device that launches element-wise work on a HIP stream;
//   (b) the operator header guarded by MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP,
//       which defines `struct convert` inside migraphx::op.
// As written the braces do not balance (there is one more `}` than `{`
// between the namespace openers and closers), so this text is not compilable
// until the two sources are separated again. The comments below mark which
// source each group of lines seems to belong to.

// (a) includes that look like they belong to the GPU device source.
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/device/nary.hpp>
// (b) include guard and includes of the operator header.
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// (a) namespaces of the GPU device source.
namespace gpu {
namespace device {
// (a) Device-side conversion entry point. Receives the HIP stream, the output
// argument, the input argument, the affine parameters (scale, shift) and the
// requested target element type.
void convert(hipStream_t stream,
const argument& result,
const argument& arg,
float scale,
float shift,
shape::type_t target_type)
// (b) namespace and struct header of the operator.
namespace op {
struct convert : unary<convert>
{
// (a) Visit the output and input arguments to obtain typed views, then take
// device pointers to their data via device_cast.
result.visit([&](auto output) {
arg.visit([&](auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
// (a) int8 targets take the clamped path (the launch on the gs_launch lines
// further down); presumably the `else` below pairs with this `if`.
if(target_type == shape::int8_type)
// (b) Operator attributes with their defaults: half target type, scale 1,
// shift 0.
shape::type_t target_type = shape::half_type;
float scale = 1.0f;
float shift = 0.0f;
// (b) Exposes the three attributes to `f` under the names "target_type",
// "scale" and "shift" and packs the results.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
}
// (b) Output shape: checks that exactly one input is given (has(1)) and
// returns a shape with target_type and the input's lens and strides.
shape compute_shape(std::vector<shape> inputs) const
{
// (a) int8 device path: one launch over result.get_shape().elements() items,
// computing input * scale + shift, applying a lower bound of -128 and an
// upper bound of 127.
// NOTE(review): std::min<int8_t> takes its arguments as int8_t, so the float
// produced by std::max is narrowed to int8_t before the 127 bound is applied;
// values above 127 would not be clamped by this call. Confirm intent.
gs_launch(stream, result.get_shape().elements())([=](auto i) {
output_ptr[i] =
std::min<int8_t>(std::max<float>(-128, input_ptr[i] * scale + shift), 127);
});
// (b) body of compute_shape.
check_shapes{inputs, *this}.has(1);
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
// (a) non-int8 device path follows (the gs_launch a few lines below).
else
// (b) Returns the element-wise function: res = scale * x + shift, with an
// extra bias and clamp when the target type is int8.
auto apply() const
{
return [&](auto x) {
float res = scale * x + shift;
if(target_type == shape::int8_type)
{
// (a) non-int8 device path: writes input * scale + shift without clamping.
gs_launch(stream, result.get_shape().elements())(
[=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; });
// (b) int8 path of apply(): add 0.5, then clamp to [-128, 127].
// NOTE(review): the + 0.5f bias rounds to nearest only if the later
// conversion to int8 truncates and the value is non-negative; for negative
// values truncation toward zero would round the wrong way. The conversion
// itself is not visible here, so confirm.
res = res + 0.5f;
res = res > 127.0 ? 127.0 : res;
res = res < -128.0 ? -128.0 : res;
}
// (a) closers of the two visit lambdas, the device function and its
// namespaces.
});
});
}
} // namespace device
} // namespace gpu
// (b) end of the apply() lambda and of apply() itself.
return res;
};
}
// (b) Constructors: target type only; target type with scale and shift; and a
// default constructor that keeps the in-class defaults.
convert(shape::type_t t) : target_type{t} {}
convert(shape::type_t t, float sle, float sft) : target_type{t}, scale{sle}, shift{sft} {}
convert() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment