Commit 8afad0f6 authored by ltqin
Browse files

open mask

parent ac3c1563
...@@ -24,6 +24,7 @@ Kernel outputs: ...@@ -24,6 +24,7 @@ Kernel outputs:
*/ */
#define PRINT_HOST 0 #define PRINT_HOST 0
// Compile-time switch tested by `#if USING_MASK`: it picks the MaskingSpec branch
// and controls whether the host reference applies its mask to s_g_m_n.
#define USING_MASK 1
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -69,8 +70,13 @@ static constexpr ck::index_t NumDimK = 1; ...@@ -69,8 +70,13 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1; static constexpr ck::index_t NumDimO = 1;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// Device-side masking mode, selected by USING_MASK so it stays in step with the
// host reference, whose masking step is guarded by the same macro.
#if USING_MASK
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#else
// Masking switched off: the device kernel must not mask either, otherwise its
// output would be compared against an unmasked host reference.
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
...@@ -203,12 +209,13 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -203,12 +209,13 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
#if 0 #if USING_MASK
const auto mask = DeviceGemmInstance::C0MatrixMask(N); auto N = s_g_m_n.GetLengths()[1];
s_g_m_n.ForEach([&](auto& self, auto idx) { const auto mask = DeviceGemmInstance::C0MatrixMask(N);
if(mask.IsMaskedElement(idx[1], idx[2])) s_g_m_n.ForEach([&](auto& self, auto idx) {
self(idx) = -ck::NumericLimits<float>::Infinity(); if(mask.IsMaskedElement(idx[1], idx[2]))
}); self(idx) = -ck::NumericLimits<float>::Infinity();
});
#endif #endif
// P = Softmax(S) // P = Softmax(S)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment