Commit 70becc77 authored by ltqin's avatar ltqin
Browse files

change mask from int32 to f16

parent 8f8c0ddc
......@@ -24,7 +24,7 @@ using B0DataType = ck::half_t;
using B1DataType = ck::half_t;
using CDataType = ck::half_t;
using D00DataType = ck::half_t;
using D01DataType = int32_t;
using D01DataType = ck::half_t;
using AccDataType = float;
struct SimpleDeviceMem
......
......@@ -399,12 +399,21 @@ struct ScaleMask
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x, const X1& mask) const;
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const int32_t& mask) const
operator()(float& y, const float& x, const int16_t& mask) const
{
float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
y = scale_ * x + filter_value;
}
// half_t-mask specialization: when the mask value is below 1.0 the element is
// filtered (mask_filter_value_ is added to the scaled input); otherwise only
// the scale is applied. Mirrors the integer-mask overload above, where mask==1
// means "keep".
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& mask) const
{
// NOTE(review): the `< 1.0f` test assumes the half-precision mask encodes
// only {0, 1}; confirm producers never emit intermediate values.
float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
y = scale_ * x + filter_value;
}
const float scale_;
const float mask_filter_value_;
};
......@@ -423,12 +432,20 @@ struct ScaleBiasMask
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& bias, const int32_t& mask) const
operator()(float& y, const float& x, const half_t& bias, const int16_t& mask) const
{
float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
// half_t-mask specialization: scales x, adds the (half -> float converted)
// bias, and additionally adds mask_filter_value_ when the mask value is below
// 1.0 (element filtered). Mirrors the integer-mask overload above, where
// mask==1 means "keep".
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& bias, const half_t& mask) const
{
// NOTE(review): the `< 1.0f` test assumes the half-precision mask encodes
// only {0, 1}; confirm producers never emit intermediate values.
float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
const float scale_;
const float mask_filter_value_;
};
......
......@@ -28,7 +28,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -48,7 +48,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -68,7 +68,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<F16, int32_t>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -88,7 +88,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<F16, int32_t>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......
......@@ -80,7 +80,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -100,7 +100,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1,
F16,
F32,
ck::Tuple<int32_t>,
ck::Tuple<F16>,
ScaleMask,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
......@@ -117,7 +117,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<int32_t>,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -137,7 +137,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1,
F16,
F32,
ck::Tuple<int32_t>,
ck::Tuple<F16>,
ScaleMask,
MaskingSpecialization::MaskDisabled>{});
}
......@@ -154,7 +154,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<F16, int32_t>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -174,7 +174,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1,
F16,
F32,
ck::Tuple<F16, int32_t>,
ck::Tuple<F16, F16>,
ScaleBiasMask,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
......@@ -191,7 +191,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16,
F16,
F16,
ck::Tuple<F16, int32_t>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
......@@ -211,7 +211,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1,
F16,
F32,
ck::Tuple<F16, int32_t>,
ck::Tuple<F16, F16>,
ScaleBiasMask,
MaskingSpecialization::MaskDisabled>{});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment