Unverified Commit 9a8a5213 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Remove virtual destructors from unary ops (#1610)

* Remove virtual destructors from unary ops

* Fixes

* Fixes

* clang format fixes
parent 7d911154
......@@ -13,15 +13,17 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct UnaryOpBase
{
public:
__host__ __device__ virtual ~UnaryOpBase() = default;
__host__ __device__ ~UnaryOpBase() = default;
__host__ __device__ UnaryOpBase() = default;
__host__ __device__ UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase() = default;
__host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
__host__ __device__ UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
__host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
......@@ -50,8 +52,14 @@ struct PassThroughPack2
constexpr const static bool is_pack2_invocable = true;
};
struct PassThrough : public UnaryOpBase
struct PassThrough final : public UnaryOpBase
{
__host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
__host__ __device__ ~PassThrough() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
......@@ -409,8 +417,15 @@ struct UnarySquare
};
};
struct UnaryAbs : public UnaryOpBase
struct UnaryAbs final : public UnaryOpBase
{
__host__ __device__ constexpr UnaryAbs() = default;
__host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
__host__ __device__ ~UnaryAbs() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = ck::math::abs(x);
......@@ -459,8 +474,15 @@ struct UnarySqrt
};
};
struct Relu : public UnaryOpBase
struct Relu final : public UnaryOpBase
{
__host__ __device__ constexpr Relu() = default;
__host__ __device__ constexpr Relu(const Relu&) = default;
__host__ __device__ constexpr Relu(Relu&&) = default;
__host__ __device__ Relu& operator=(const Relu&) = default;
__host__ __device__ Relu& operator=(Relu&&) = default;
__host__ __device__ ~Relu() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = x > 0 ? x : 0;
......@@ -633,8 +655,14 @@ struct Gelu
}
};
struct Sigmoid : public UnaryOpBase
struct Sigmoid final : public UnaryOpBase
{
__host__ __device__ constexpr Sigmoid() = default;
__host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
......@@ -688,8 +716,15 @@ struct Silu
};
};
struct TanH : public UnaryOpBase
struct TanH final : public UnaryOpBase
{
__host__ __device__ constexpr TanH() = default;
__host__ __device__ constexpr TanH(const TanH&) = default;
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = ck::math::tanh(x);
......@@ -959,8 +994,12 @@ struct Rcp
};
};
struct Swish : public UnaryOpBase
struct Swish final : public UnaryOpBase
{
__host__ __device__ constexpr Swish(const Swish&) = default;
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ float get_beta() const { return beta_; }
......@@ -1019,8 +1058,12 @@ struct Swish : public UnaryOpBase
}
};
struct SoftRelu : public UnaryOpBase
struct SoftRelu final : public UnaryOpBase
{
__host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;
__host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
......@@ -1070,8 +1113,12 @@ struct SoftRelu : public UnaryOpBase
}
};
struct Power : public UnaryOpBase
struct Power final : public UnaryOpBase
{
__host__ __device__ constexpr Power(const Power&) = default;
__host__ __device__ constexpr Power(Power&&) = default;
__host__ __device__ ~Power() = default;
__host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma)
{
......@@ -1148,8 +1195,12 @@ struct Power : public UnaryOpBase
}
};
struct ClippedRelu : public UnaryOpBase
struct ClippedRelu final : public UnaryOpBase
{
__host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;
__host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
: alpha_(alpha), beta_(beta)
{
......@@ -1205,8 +1256,11 @@ struct ClippedRelu : public UnaryOpBase
}
};
struct LeakyRelu : public UnaryOpBase
struct LeakyRelu final : public UnaryOpBase
{
__host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;
__host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
......@@ -1250,8 +1304,11 @@ struct LeakyRelu : public UnaryOpBase
}
};
struct Elu : public UnaryOpBase
struct Elu final : public UnaryOpBase
{
__host__ __device__ constexpr Elu(const Elu&) = default;
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;
__host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
......@@ -1296,8 +1353,11 @@ struct Elu : public UnaryOpBase
}
};
struct Logistic : public UnaryOpBase
struct Logistic final : public UnaryOpBase
{
__host__ __device__ constexpr Logistic(const Logistic&) = default;
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;
__host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
......@@ -1631,8 +1691,23 @@ struct DynamicUnaryOp
__host__ __device__ ~DynamicUnaryOp()
{
if(unary_op_ptr_)
delete unary_op_ptr_;
switch(unary_op_type_)
{
case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
default: break;
}
}
__device__ void InitUnaryOpPtrOnDevice()
......@@ -1721,6 +1796,7 @@ struct DynamicUnaryOp
float beta;
float gamma;
};
#pragma clang diagnostic pop
} // namespace element_wise
} // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment