"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0ceeb3b4d94e1249aa300b8c4cae21abcab8b5a8"
Unverified Commit 9a8a5213 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Remove virtual destructors from unary ops (#1610)

* Remove virtual destructors from unary ops

* Fixes

* Fixes

* clang format fixes
parent 7d911154
...@@ -13,15 +13,17 @@ namespace ck { ...@@ -13,15 +13,17 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct UnaryOpBase struct UnaryOpBase
{ {
public: public:
__host__ __device__ virtual ~UnaryOpBase() = default; __host__ __device__ ~UnaryOpBase() = default;
__host__ __device__ UnaryOpBase() = default; __host__ __device__ constexpr UnaryOpBase() = default;
__host__ __device__ UnaryOpBase(const UnaryOpBase&) = default; __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default; __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
__host__ __device__ UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default; __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
__host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0; __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
...@@ -50,8 +52,14 @@ struct PassThroughPack2 ...@@ -50,8 +52,14 @@ struct PassThroughPack2
constexpr const static bool is_pack2_invocable = true; constexpr const static bool is_pack2_invocable = true;
}; };
struct PassThrough : public UnaryOpBase struct PassThrough final : public UnaryOpBase
{ {
__host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
__host__ __device__ ~PassThrough() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; } __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
...@@ -409,8 +417,15 @@ struct UnarySquare ...@@ -409,8 +417,15 @@ struct UnarySquare
}; };
}; };
struct UnaryAbs : public UnaryOpBase struct UnaryAbs final : public UnaryOpBase
{ {
__host__ __device__ constexpr UnaryAbs() = default;
__host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
__host__ __device__ ~UnaryAbs() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final __host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
y = ck::math::abs(x); y = ck::math::abs(x);
...@@ -459,8 +474,15 @@ struct UnarySqrt ...@@ -459,8 +474,15 @@ struct UnarySqrt
}; };
}; };
struct Relu : public UnaryOpBase struct Relu final : public UnaryOpBase
{ {
__host__ __device__ constexpr Relu() = default;
__host__ __device__ constexpr Relu(const Relu&) = default;
__host__ __device__ constexpr Relu(Relu&&) = default;
__host__ __device__ Relu& operator=(const Relu&) = default;
__host__ __device__ Relu& operator=(Relu&&) = default;
__host__ __device__ ~Relu() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final __host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
y = x > 0 ? x : 0; y = x > 0 ? x : 0;
...@@ -633,8 +655,14 @@ struct Gelu ...@@ -633,8 +655,14 @@ struct Gelu
} }
}; };
struct Sigmoid : public UnaryOpBase struct Sigmoid final : public UnaryOpBase
{ {
__host__ __device__ constexpr Sigmoid() = default;
__host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final __host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
...@@ -688,8 +716,15 @@ struct Silu ...@@ -688,8 +716,15 @@ struct Silu
}; };
}; };
struct TanH : public UnaryOpBase struct TanH final : public UnaryOpBase
{ {
__host__ __device__ constexpr TanH() = default;
__host__ __device__ constexpr TanH(const TanH&) = default;
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final __host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
y = ck::math::tanh(x); y = ck::math::tanh(x);
...@@ -959,8 +994,12 @@ struct Rcp ...@@ -959,8 +994,12 @@ struct Rcp
}; };
}; };
struct Swish : public UnaryOpBase struct Swish final : public UnaryOpBase
{ {
__host__ __device__ constexpr Swish(const Swish&) = default;
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {} __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ float get_beta() const { return beta_; } __host__ __device__ float get_beta() const { return beta_; }
...@@ -1019,8 +1058,12 @@ struct Swish : public UnaryOpBase ...@@ -1019,8 +1058,12 @@ struct Swish : public UnaryOpBase
} }
}; };
struct SoftRelu : public UnaryOpBase struct SoftRelu final : public UnaryOpBase
{ {
__host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;
__host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {} __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; } __host__ __device__ float get_alpha() const { return alpha_; }
...@@ -1070,8 +1113,12 @@ struct SoftRelu : public UnaryOpBase ...@@ -1070,8 +1113,12 @@ struct SoftRelu : public UnaryOpBase
} }
}; };
struct Power : public UnaryOpBase struct Power final : public UnaryOpBase
{ {
__host__ __device__ constexpr Power(const Power&) = default;
__host__ __device__ constexpr Power(Power&&) = default;
__host__ __device__ ~Power() = default;
__host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma) : alpha_(alpha), beta_(beta), gamma_(gamma)
{ {
...@@ -1148,8 +1195,12 @@ struct Power : public UnaryOpBase ...@@ -1148,8 +1195,12 @@ struct Power : public UnaryOpBase
} }
}; };
struct ClippedRelu : public UnaryOpBase struct ClippedRelu final : public UnaryOpBase
{ {
__host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;
__host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
: alpha_(alpha), beta_(beta) : alpha_(alpha), beta_(beta)
{ {
...@@ -1205,8 +1256,11 @@ struct ClippedRelu : public UnaryOpBase ...@@ -1205,8 +1256,11 @@ struct ClippedRelu : public UnaryOpBase
} }
}; };
struct LeakyRelu : public UnaryOpBase struct LeakyRelu final : public UnaryOpBase
{ {
__host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;
__host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {} __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
...@@ -1250,8 +1304,11 @@ struct LeakyRelu : public UnaryOpBase ...@@ -1250,8 +1304,11 @@ struct LeakyRelu : public UnaryOpBase
} }
}; };
struct Elu : public UnaryOpBase struct Elu final : public UnaryOpBase
{ {
__host__ __device__ constexpr Elu(const Elu&) = default;
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;
__host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {} __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
...@@ -1296,8 +1353,11 @@ struct Elu : public UnaryOpBase ...@@ -1296,8 +1353,11 @@ struct Elu : public UnaryOpBase
} }
}; };
struct Logistic : public UnaryOpBase struct Logistic final : public UnaryOpBase
{ {
__host__ __device__ constexpr Logistic(const Logistic&) = default;
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;
__host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {} __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
...@@ -1631,8 +1691,23 @@ struct DynamicUnaryOp ...@@ -1631,8 +1691,23 @@ struct DynamicUnaryOp
__host__ __device__ ~DynamicUnaryOp() __host__ __device__ ~DynamicUnaryOp()
{ {
if(unary_op_ptr_) switch(unary_op_type_)
delete unary_op_ptr_; {
case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
default: break;
}
} }
__device__ void InitUnaryOpPtrOnDevice() __device__ void InitUnaryOpPtrOnDevice()
...@@ -1721,6 +1796,7 @@ struct DynamicUnaryOp ...@@ -1721,6 +1796,7 @@ struct DynamicUnaryOp
float beta; float beta;
float gamma; float gamma;
}; };
#pragma clang diagnostic pop
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment