Commit 34493a8d authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 3003b4e3
...@@ -20,15 +20,14 @@ namespace op { ...@@ -20,15 +20,14 @@ namespace op {
struct convert : unary<convert> struct convert : unary<convert>
{ {
shape::type_t target_type = shape::half_type; shape::type_t target_type = shape::half_type;
float scale = 1.0f; float scale = 1.0f;
float shift = 0.0f; float shift = 0.0f;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.target_type, "target_type"), return pack(
f(self.scale, "scale"), f(self.target_type, "target_type"), f(self.scale, "scale"), f(self.shift, "shift"));
f(self.shift, "shift"));
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -39,7 +38,8 @@ struct convert : unary<convert> ...@@ -39,7 +38,8 @@ struct convert : unary<convert>
auto apply() const auto apply() const
{ {
// return [&](auto x) { return (target_type == shape::int8_type) ? static_cast<int8_t>(x * scale + shift) : x; }; // return [&](auto x) { return (target_type == shape::int8_type) ? static_cast<int8_t>(x *
// scale + shift) : x; };
return [&](auto x) { return scale * x + shift; }; return [&](auto x) { return scale * x + shift; };
} }
......
...@@ -6,14 +6,15 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,14 +6,15 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void convert(hipStream_t stream, const argument& result, const argument& arg, float scale, float shift) void convert(
hipStream_t stream, const argument& result, const argument& arg, float scale, float shift)
{ {
result.visit([&](auto output) { result.visit([&](auto output) {
arg.visit([&](auto input) { arg.visit([&](auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
gs_launch(stream, gs_launch(stream, result.get_shape().elements())(
result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; }); [=](auto i) { output_ptr[i] = input_ptr[i] * scale + shift; });
}); });
}); });
} }
......
...@@ -21,11 +21,10 @@ struct hip_convert ...@@ -21,11 +21,10 @@ struct hip_convert
} }
std::string name() const { return "gpu::convert"; } std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const; shape compute_shape(std::vector<shape> inputs) const;
argument argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
......
...@@ -11,7 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void convert(hipStream_t stream, const argument& result, const argument& arg, float scale, float shift); void convert(
hipStream_t stream, const argument& result, const argument& arg, float scale, float shift);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment