Commit d5f629e7 authored by ltqin
Browse files

fix example

parent 92b9b046
......@@ -447,8 +447,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -268,7 +268,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 123;
ck::index_t M = 253;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
......
......@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false;
// If 32 < DIM <= 64, use prototype1 2nd template.
// If 64 < DIM <= 128, use prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
using DeviceGemmInstance = ck::tensor_operation::device::
DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......@@ -172,8 +172,8 @@ using DeviceGemmInstance =
MaskingSpec, // MaskingSpecialization
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
using DeviceGemmInstance = ck::tensor_operation::device::
DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......@@ -461,8 +461,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -446,8 +446,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -226,8 +226,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -222,7 +222,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -304,7 +304,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -274,7 +274,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -369,7 +369,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......
......@@ -52,6 +52,7 @@ struct MaskOutUpperTrianglePredicate
};
struct MaskUpperTringleFromBottonRightPredicate
{
MaskUpperTringleFromBottonRightPredicate() : offset_(0) {}
__host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; }
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const
{
......@@ -78,6 +79,7 @@ struct C0MatrixMask_impl
if constexpr(std::is_same<MaskOutPredicate,
MaskUpperTringleFromBottonRightPredicate>::value)
{
if(NRaw > MRaw)
predicate_.SetOffset(NRaw - MRaw);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment