"...composable_kernel_rocm.git" did not exist on "e4112de7303af1601c6590b964d0df2b6a7f7d32"
Commit 3c835d76 authored by ltqin's avatar ltqin
Browse files

move scalemask out of inter elementop

parent 05fc2f8e
...@@ -10,9 +10,29 @@ ...@@ -10,9 +10,29 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
// Element-wise functor that computes y = scale_ * x + offset, where the offset
// is mask_filter_value_ when the mask element is below 1 and 0 otherwise.
// NOTE(review): mask_filter_value is presumably a large negative number so that
// masked entries vanish in a later softmax -- confirm against the caller.
struct ScaleMask
{
    // Multiplier applied to every input element.
    const float scale_;
    // Value added to the scaled input wherever the mask element is below 1.
    const float mask_filter_value_;

    ScaleMask(float scale, float mask_filter_value)
        : scale_{scale}, mask_filter_value_{mask_filter_value}
    {
    }

    // Primary template is only declared; just the specialization below has a
    // definition, so other type combinations have no implementation here.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x, const X1& mask) const;

    // float output, float input, half-precision mask.
    template <>
    __host__ __device__ constexpr void
    operator()<float, float, ck::half_t>(float& y, const float& x, const ck::half_t& mask) const
    {
        // A mask element strictly below 1 selects the filter value.
        const bool masked_out = mask < 1.0f;
        y                     = scale_ * x + (masked_out ? mask_filter_value_ : 0.0f);
    }
};
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough; using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleMask; using Acc0ElementOp = ScaleMask;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough; using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
......
...@@ -389,35 +389,6 @@ struct UnaryTypeConvert<ck::bhalf_t, float> ...@@ -389,35 +389,6 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
} }
}; };
// Element-wise functor that computes y = scale_ * x + offset, where the offset
// is either 0 or mask_filter_value_ depending on the mask element.
// NOTE(review): mask_filter_value is presumably a large negative number so that
// masked entries vanish in a later softmax -- confirm against the caller.
struct ScaleMask
{
    // Multiplier applied to every input element.
    const float scale_;
    // Value added to the scaled input for mask elements that select filtering.
    const float mask_filter_value_;

    ScaleMask(float scale, float mask_filter_value)
        : scale_{scale}, mask_filter_value_{mask_filter_value}
    {
    }

    // Primary template is only declared; just the specializations below have
    // definitions, so other type combinations have no implementation here.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x, const X1& mask) const;

    // Integer mask: exactly 1 adds nothing, any other value adds the filter value.
    template <>
    __host__ __device__ constexpr void
    operator()<float, float, int16_t>(float& y, const float& x, const int16_t& mask) const
    {
        const float offset = (mask != 1) ? mask_filter_value_ : 0.0f;
        y                  = scale_ * x + offset;
    }

    // Half-precision mask: values strictly below 1 add the filter value.
    template <>
    __host__ __device__ constexpr void
    operator()<float, float, half_t>(float& y, const float& x, const half_t& mask) const
    {
        const bool masked_out = mask < 1.0f;
        y                     = scale_ * x + (masked_out ? mask_filter_value_ : 0.0f);
    }
};
struct ScaleBiasMask struct ScaleBiasMask
{ {
ScaleBiasMask(float scale, float mask_filter_value) ScaleBiasMask(float scale, float mask_filter_value)
......
...@@ -21,8 +21,6 @@ template <ck::index_t... Is> ...@@ -21,8 +21,6 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleMask = ck::tensor_operation::element_wise::ScaleMask;
using ScaleBiasMask = ck::tensor_operation::element_wise::ScaleBiasMask;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
......
...@@ -18,7 +18,6 @@ using F16 = ck::half_t; ...@@ -18,7 +18,6 @@ using F16 = ck::half_t;
using F32 = float; using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleMask = ck::tensor_operation::element_wise::ScaleMask;
using ScaleBiasMask = ck::tensor_operation::element_wise::ScaleBiasMask; using ScaleBiasMask = ck::tensor_operation::element_wise::ScaleBiasMask;
// f16 ScaleBiasMask masking // f16 ScaleBiasMask masking
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment