Commit 8afad0f6 authored by ltqin's avatar ltqin
Browse files

open mask

parent ac3c1563
......@@ -24,6 +24,7 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK 1
#include <iostream>
#include <numeric>
......@@ -69,8 +70,13 @@ static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
......@@ -203,7 +209,8 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
#if 0
#if USING_MASK
auto N = s_g_m_n.GetLengths()[1];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment