"docs/vscode:/vscode.git/clone" did not exist on "3ba36f97b89184357bfa2bf74f94ab18c6dea9d5"
Commit 70becc77 authored by ltqin's avatar ltqin
Browse files

change mask from int32 to f16

parent 8f8c0ddc
...@@ -24,7 +24,7 @@ using B0DataType = ck::half_t; ...@@ -24,7 +24,7 @@ using B0DataType = ck::half_t;
using B1DataType = ck::half_t; using B1DataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using D00DataType = ck::half_t; using D00DataType = ck::half_t;
using D01DataType = int32_t; using D01DataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
struct SimpleDeviceMem struct SimpleDeviceMem
......
...@@ -399,12 +399,21 @@ struct ScaleMask ...@@ -399,12 +399,21 @@ struct ScaleMask
template <typename Y, typename X0, typename X1> template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x, const X1& mask) const; __host__ __device__ constexpr void operator()(Y& y, const X0& x, const X1& mask) const;
template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(float& y, const float& x, const int32_t& mask) const operator()(float& y, const float& x, const int16_t& mask) const
{ {
float filter_value = (mask == 1 ? 0.0f : mask_filter_value_); float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
y = scale_ * x + filter_value; y = scale_ * x + filter_value;
} }
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& mask) const
{
float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
y = scale_ * x + filter_value;
}
const float scale_; const float scale_;
const float mask_filter_value_; const float mask_filter_value_;
}; };
...@@ -423,12 +432,20 @@ struct ScaleBiasMask ...@@ -423,12 +432,20 @@ struct ScaleBiasMask
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& bias, const int32_t& mask) const operator()(float& y, const float& x, const half_t& bias, const int16_t& mask) const
{ {
float filter_value = (mask == 1 ? 0.0f : mask_filter_value_); float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value; y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
} }
template <>
__host__ __device__ constexpr void
operator()(float& y, const float& x, const half_t& bias, const half_t& mask) const
{
float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
}
const float scale_; const float scale_;
const float mask_filter_value_; const float mask_filter_value_;
}; };
......
...@@ -28,7 +28,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -28,7 +28,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<int32_t>, ck::Tuple<F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -48,7 +48,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -48,7 +48,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<int32_t>, ck::Tuple<F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -68,7 +68,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -68,7 +68,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<F16, int32_t>, ck::Tuple<F16, F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -88,7 +88,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -88,7 +88,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<F16, int32_t>, ck::Tuple<F16, F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
......
...@@ -80,7 +80,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -80,7 +80,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<int32_t>, ck::Tuple<F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -100,7 +100,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -100,7 +100,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1, 1,
F16, F16,
F32, F32,
ck::Tuple<int32_t>, ck::Tuple<F16>,
ScaleMask, ScaleMask,
MaskingSpecialization::MaskOutUpperTriangle>{}); MaskingSpecialization::MaskOutUpperTriangle>{});
} }
...@@ -117,7 +117,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -117,7 +117,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<int32_t>, ck::Tuple<F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -137,7 +137,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -137,7 +137,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1, 1,
F16, F16,
F32, F32,
ck::Tuple<int32_t>, ck::Tuple<F16>,
ScaleMask, ScaleMask,
MaskingSpecialization::MaskDisabled>{}); MaskingSpecialization::MaskDisabled>{});
} }
...@@ -154,7 +154,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -154,7 +154,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<F16, int32_t>, ck::Tuple<F16, F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -174,7 +174,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -174,7 +174,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1, 1,
F16, F16,
F32, F32,
ck::Tuple<F16, int32_t>, ck::Tuple<F16, F16>,
ScaleBiasMask, ScaleBiasMask,
MaskingSpecialization::MaskOutUpperTriangle>{}); MaskingSpecialization::MaskOutUpperTriangle>{});
} }
...@@ -191,7 +191,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -191,7 +191,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
F16, F16,
F16, F16,
F16, F16,
ck::Tuple<F16, int32_t>, ck::Tuple<F16, F16>,
ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -211,7 +211,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk ...@@ -211,7 +211,7 @@ void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk
1, 1,
F16, F16,
F32, F32,
ck::Tuple<F16, int32_t>, ck::Tuple<F16, F16>,
ScaleBiasMask, ScaleBiasMask,
MaskingSpecialization::MaskDisabled>{}); MaskingSpecialization::MaskDisabled>{});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment