Commit d5f629e7 authored by ltqin's avatar ltqin
Browse files

fix example

parent 92b9b046
...@@ -447,8 +447,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -447,8 +447,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -268,7 +268,7 @@ int run(int argc, char* argv[]) ...@@ -268,7 +268,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 123; ck::index_t M = 253;
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
......
...@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false; ...@@ -102,8 +102,8 @@ static constexpr bool Deterministic = false;
// If 32 < DIM <= 64 , ues prototype1 2nd template. // If 32 < DIM <= 64 , ues prototype1 2nd template.
// If 64 < DIM <= 128, ues prototype2 2nd template. // If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance = ck::tensor_operation::device::
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1< DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -172,8 +172,8 @@ using DeviceGemmInstance = ...@@ -172,8 +172,8 @@ using DeviceGemmInstance =
MaskingSpec, // MaskingSpecialization MaskingSpec, // MaskingSpecialization
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance = ck::tensor_operation::device::
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1< DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -461,8 +461,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -461,8 +461,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -446,8 +446,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -446,8 +446,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -226,8 +226,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -226,8 +226,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2]; auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) { s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -222,7 +222,7 @@ int run(int argc, char* argv[]) ...@@ -222,7 +222,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -304,7 +304,7 @@ int run(int argc, char* argv[]) ...@@ -304,7 +304,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -274,7 +274,7 @@ int run(int argc, char* argv[]) ...@@ -274,7 +274,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -369,7 +369,7 @@ int run(int argc, char* argv[]) ...@@ -369,7 +369,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N); const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
......
...@@ -52,6 +52,7 @@ struct MaskOutUpperTrianglePredicate ...@@ -52,6 +52,7 @@ struct MaskOutUpperTrianglePredicate
}; };
struct MaskUpperTringleFromBottonRightPredicate struct MaskUpperTringleFromBottonRightPredicate
{ {
MaskUpperTringleFromBottonRightPredicate() : offset_(0) {}
__host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; } __host__ __device__ void SetOffset(const index_t offset) { offset_ = offset; }
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const __host__ __device__ constexpr bool operator()(index_t m, index_t n) const
{ {
...@@ -78,7 +79,8 @@ struct C0MatrixMask_impl ...@@ -78,7 +79,8 @@ struct C0MatrixMask_impl
if constexpr(std::is_same<MaskOutPredicate, if constexpr(std::is_same<MaskOutPredicate,
MaskUpperTringleFromBottonRightPredicate>::value) MaskUpperTringleFromBottonRightPredicate>::value)
{ {
predicate_.SetOffset(NRaw - MRaw); if(NRaw > MRaw)
predicate_.SetOffset(NRaw - MRaw);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment