Commit f1ed4c5e authored by ltqin's avatar ltqin
Browse files

add lib gemm softmax gemm files

parent 1e59eb3b
...@@ -389,6 +389,49 @@ struct UnaryTypeConvert<ck::bhalf_t, float> ...@@ -389,6 +389,49 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
} }
}; };
// Element-wise functor: multiplies the input by a fixed scale and adds an
// offset selected by a mask value.
struct ScaleMask
{
    // scale             - multiplier applied to the input element.
    // mask_filter_value - offset added whenever the mask value is not 1.
    ScaleMask(float scale, float mask_filter_value)
        : scale_{scale}, mask_filter_value_{mask_filter_value}
    {
    }

    // scale, masked
    // Generic form: declared here only; no definition is visible in this chunk.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x, const X1& mask) const;

    // Concrete float/int32 form: writes y = scale_ * x + offset, where the
    // offset is 0 when mask == 1 and mask_filter_value_ for any other mask.
    __host__ __device__ constexpr void
    operator()(float& y, const float& x, const int32_t& mask) const
    {
        // Start from the "filtered" offset and drop it when the mask is exactly 1.
        float additive = mask_filter_value_;
        if(mask == 1)
        {
            additive = 0.0f;
        }
        y = scale_ * x + additive;
    }

    const float scale_;
    const float mask_filter_value_;
};
// Element-wise functor: multiplies the input by a fixed scale, adds a bias
// element, and adds an offset selected by a mask value.
struct ScaleBiasMask
{
    // scale             - multiplier applied to the input element.
    // mask_filter_value - offset added whenever the mask value is not 1.
    ScaleBiasMask(float scale, float mask_filter_value)
        : scale_(scale), mask_filter_value_(mask_filter_value)
    {
    }

    // biased, masked
    // Generic form: declared here only; no definition is visible in this chunk.
    template <typename Y, typename X0, typename X1, typename X2>
    __host__ __device__ constexpr void
    operator()(Y& y, const X0& x, const X1& bias, const X2& mask) const;

    // Concrete float/F16/int32 form, written as a plain overload rather than an
    // in-class `template <>` explicit specialization: class-scope explicit
    // specializations depend on CWG 727 support, which not every host compiler
    // provides. A non-template overload is still preferred by overload
    // resolution for these exact argument types.
    //
    // Writes y = scale_ * x + float(bias) + offset, where the offset is 0 when
    // mask == 1 and mask_filter_value_ for any other mask value.
    __host__ __device__ constexpr void
    operator()(float& y, const float& x, const F16& bias, const int32_t& mask) const
    {
        const float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
        // The bias is converted to float before accumulation.
        y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
    }

    const float scale_;
    const float mask_filter_value_;
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -3,5 +3,6 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ...@@ -3,5 +3,6 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
) )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment