Commit 446204c7 authored by skrider, committed by Woosuk Kwon

tests passing for single page k

parent a3e06cd5
@@ -17,9 +17,89 @@
printf("\n[kin:end:%s]\n", #BOOL); \
}
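// A minimal sketch (hypothetical, not the author's definition -- the real
// KIN_PRINT / KIN_PRINT_BOOL macros are defined above this hunk) of how such
// a thread-0-gated debug macro can be built, inferred only from the
// "[kin:end:%s]" tail visible here:
#define KIN_PRINT_SKETCH(EXPR)                   \
    if (cute::thread0()) {                       \
        printf("\n[kin:begin:%s]\n", #EXPR);     \
        EXPR;                                    \
        printf("\n[kin:end:%s]\n", #EXPR);       \
    }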
__forceinline__ __device__
void print_qkv_params(const Qkv_params& params) {
// LLM generated
printf("Qkv_params:\n");
printf("q_ptr: %p\n", params.q_ptr);
printf("k_ptr: %p\n", params.k_ptr);
printf("v_ptr: %p\n", params.v_ptr);
printf("q_batch_stride: %" PRId64 "\n", params.q_batch_stride);
printf("k_batch_stride: %" PRId64 "\n", params.k_batch_stride);
printf("v_batch_stride: %" PRId64 "\n", params.v_batch_stride);
printf("q_row_stride: %" PRId64 "\n", params.q_row_stride);
printf("k_row_stride: %" PRId64 "\n", params.k_row_stride);
printf("v_row_stride: %" PRId64 "\n", params.v_row_stride);
printf("q_head_stride: %" PRId64 "\n", params.q_head_stride);
printf("k_head_stride: %" PRId64 "\n", params.k_head_stride);
printf("v_head_stride: %" PRId64 "\n", params.v_head_stride);
printf("h: %d\n", params.h);
printf("h_k: %d\n", params.h_k);
printf("h_h_k_ratio: %d\n", params.h_h_k_ratio);
}
__forceinline__ __device__
void print_flash_fwd_params(const Flash_fwd_params& params) {
print_qkv_params(params);
// LLM generated
printf("struct Flash_fwd_params:\n");
printf("o_ptr: %p\n", params.o_ptr);
printf("oaccum_ptr: %p\n", params.oaccum_ptr);
printf("o_batch_stride: %ld\n", params.o_batch_stride);
printf("o_row_stride: %ld\n", params.o_row_stride);
printf("o_head_stride: %ld\n", params.o_head_stride);
printf("p_ptr: %p\n", params.p_ptr);
printf("softmax_lse_ptr: %p\n", params.softmax_lse_ptr);
printf("softmax_lseaccum_ptr: %p\n", params.softmax_lseaccum_ptr);
printf("b: %d\n", params.b);
printf("seqlen_q: %d\n", params.seqlen_q);
printf("seqlen_k: %d\n", params.seqlen_k);
printf("seqlen_knew: %d\n", params.seqlen_knew);
printf("d: %d\n", params.d);
printf("seqlen_q_rounded: %d\n", params.seqlen_q_rounded);
printf("seqlen_k_rounded: %d\n", params.seqlen_k_rounded);
printf("d_rounded: %d\n", params.d_rounded);
printf("rotary_dim: %d\n", params.rotary_dim);
printf("scale_softmax: %f\n", params.scale_softmax);
printf("scale_softmax_log2: %f\n", params.scale_softmax_log2);
printf("cu_seqlens_q: %p\n", params.cu_seqlens_q);
printf("cu_seqlens_k: %p\n", params.cu_seqlens_k);
printf("seqused_k: %p\n", params.seqused_k);
printf("blockmask: %p\n", params.blockmask);
printf("knew_ptr: %p\n", params.knew_ptr);
printf("vnew_ptr: %p\n", params.vnew_ptr);
printf("knew_batch_stride: %ld\n", params.knew_batch_stride);
printf("vnew_batch_stride: %ld\n", params.vnew_batch_stride);
printf("knew_row_stride: %ld\n", params.knew_row_stride);
printf("vnew_row_stride: %ld\n", params.vnew_row_stride);
printf("knew_head_stride: %ld\n", params.knew_head_stride);
printf("vnew_head_stride: %ld\n", params.vnew_head_stride);
printf("rotary_cos_ptr: %p\n", params.rotary_cos_ptr);
printf("rotary_sin_ptr: %p\n", params.rotary_sin_ptr);
printf("cache_batch_idx: %p\n", params.cache_batch_idx);
printf("block_table: %p\n", params.block_table);
printf("block_table_batch_stride: %ld\n", params.block_table_batch_stride);
printf("page_block_size: %d\n", params.page_block_size);
printf("p_dropout: %f\n", params.p_dropout);
printf("p_dropout_in_uint8_t: %u\n", params.p_dropout_in_uint8_t);
printf("rp_dropout: %f\n", params.rp_dropout);
printf("scale_softmax_rp_dropout: %f\n", params.scale_softmax_rp_dropout);
printf("window_size_left: %d\n", params.window_size_left);
printf("window_size_right: %d\n", params.window_size_right);
printf("philox_args: %p\n", &(params.philox_args));
printf("rng_state: %p\n", params.rng_state);
printf("is_bf16: %d\n", params.is_bf16);
printf("is_causal: %d\n", params.is_causal);
printf("is_seqlens_k_cumulative: %d\n", params.is_seqlens_k_cumulative);
printf("is_rotary_interleaved: %d\n", params.is_rotary_interleaved);
printf("num_splits: %d\n", params.num_splits);
printf("alibi_slopes_ptr: %p\n", params.alibi_slopes_ptr);
printf("alibi_slopes_batch_stride: %ld\n", params.alibi_slopes_batch_stride);
}
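// Hypothetical usage sketch (not part of the diff): gate the dump to one
// thread so the output is not interleaved across the whole grid, e.g.
//   if (cute::thread0()) { print_flash_fwd_params(params); }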
template<typename Kernel_traits>
__forceinline__ __device__ void
print_traits() {
__forceinline__ __device__
void print_traits() {
// bool
printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false");
printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false");
@@ -36,7 +116,8 @@ print_traits() {
printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize );
printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize );
printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize );
printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread );
printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread);
printf("Kernel_traits::kGmemThreadsPerRow: %d\n", Kernel_traits::kGmemThreadsPerRow);
printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad );
// cute object
@@ -43,9 +43,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
#if 1
KIN_PRINT(print_traits<Kernel_traits>());
#endif
auto seed_offset = at::cuda::philox::unpack(params.philox_args);
flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
@@ -60,9 +57,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 1
KIN_PRINT(print_binfo(binfo))
#endif
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
@@ -144,19 +138,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutKV{});
#if 1
KIN_PRINT(print(sK.layout()))
KIN_PRINT(print(gK.layout()))
#endif
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
KIN_PRINT(print(sV.layout()))
KIN_PRINT(print(sVt.layout()))
KIN_PRINT(print(sVtNoSwizzle.layout()))
#endif
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -167,27 +152,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
#if 1
KIN_PRINT(print(tKgK.layout()))
KIN_PRINT(print(tKsK.layout()))
#endif
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT(print(tSrQ.layout()))
KIN_PRINT(print(tSrK.layout()))
#endif
Tensor tSgS = thr_mma.partition_C(gP);
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
#if 1
KIN_PRINT(print(acc_o.layout()))
#endif
//
// Copy Atom retiling
@@ -195,22 +169,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
#if 0
KIN_PRINT(smem_thr_copy_Q.print_all());
#endif
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
KIN_PRINT(print(tSsQ.layout()))
#endif
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
#if 1
KIN_PRINT(print(tSsK.layout()))
#endif
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
@@ -227,10 +192,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT(print(cQ.layout()))
KIN_PRINT(print(cKV.layout()))
#endif
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// if (cute::thread0()) {
// print(tScQ.layout()); printf("\n");
@@ -251,12 +212,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
KIN_PRINT(print(tQcQ.layout()))
KIN_PRINT(print(tKVcKV.layout()))
KIN_PRINT(print(tQpQ.layout()))
KIN_PRINT(print(tKVpKV.layout()))
#endif
// Set predicates for k bounds
if (!Is_even_K) {
@@ -538,13 +493,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
constexpr int kNWarps = Kernel_traits::kNWarps;
#if 1
KIN_PRINT(print_traits<Kernel_traits>())
KIN_PRINT_BOOL(Is_causal)
KIN_PRINT_BOOL(Is_local)
KIN_PRINT_BOOL(Has_alibi)
KIN_PRINT_BOOL(Is_even_MN)
KIN_PRINT_BOOL(Is_even_K)
KIN_PRINT_BOOL(Split)
KIN_PRINT_BOOL(Append_KV)
KIN_PRINT(print_flash_fwd_params(params))
#endif
using GmemTiledCopyO = std::conditional_t<
@@ -558,9 +507,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 1
KIN_PRINT(print_binfo(binfo))
#endif
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
const int n_block_min = !Is_local
@@ -625,17 +571,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread
const index_t row_offset_k__shadow = block_table[(n_block_max - 1) * kBlockN / params.page_block_size] * params.k_batch_stride + (((n_block_max - 1) * kBlockN) % params.page_block_size) * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
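// Worked example of the shadow offset above, with illustrative numbers (not
// from the diff): with kBlockN = 64, page_block_size = 256, n_block_max = 3,
// the first row of the last n-block is (3 - 1) * 64 = 128; it falls in
// logical page 128 / 256 = 0, at row 128 % 256 = 128 within that page, so the
// offset is block_table[0] * k_batch_stride + 128 * k_row_stride plus the
// head offset. This addresses the whole block correctly only when the block
// fits inside one page, matching this commit's "single page k" scope.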
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gK__shadow = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k__shadow),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
@@ -646,13 +599,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
KIN_PRINT(print(sK.layout()))
KIN_PRINT(print(gK.layout()))
KIN_PRINT(print(sV.layout()))
KIN_PRINT(print(sVt.layout()))
KIN_PRINT(print(sVtNoSwizzle.layout()))
#endif
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
@@ -662,27 +608,31 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKgK__shadow = gmem_thr_copy_KV.partition_S(gK__shadow); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
if (block_table != nullptr) {
tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
#if 1
KIN_PRINT(print(tKgK.layout()))
KIN_PRINT(print(tKsK.layout()))
#endif
#if 1
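// Debug technique: stamp each thread's V slice with its own thread index so
// a mis-addressed page shows up as an unexpected id in the tensor dumps below.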
fill(tVgV, 1.f * ((Element) tidx));
__syncthreads();
KIN_PRINT(print_tensor(gV))
KIN_PRINT([&]() {
for (int i = 0; i < n_block_max; i++) {
printf("%d ", block_table[i]);
}
}())
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (thread0()) {
// gK.data() = tKgK.data();
// }
KIN_PRINT(print_tensor(tKgK))
KIN_PRINT(print_tensor(gK))
KIN_PRINT(print_tensor(tKgK__shadow))
KIN_PRINT(print_tensor(gK__shadow))
#endif
typename Kernel_traits::TiledMma tiled_mma;
@@ -690,15 +640,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT(print(tSrQ.layout()))
KIN_PRINT(print(tSrK.layout()))
#endif
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
#if 1
KIN_PRINT(print(acc_o.layout()))
#endif
//
// Copy Atom retiling
@@ -707,16 +650,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
KIN_PRINT(print(tSsQ.layout()))
#endif
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
#if 1
KIN_PRINT(print(tSsK.layout()))
#endif
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
@@ -732,10 +669,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT(print(cQ.layout()))
KIN_PRINT(print(cKV.layout()))
#endif
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
@@ -744,12 +677,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
KIN_PRINT(print(tQcQ.layout()))
KIN_PRINT(print(tKVcKV.layout()))
KIN_PRINT(print(tQpQ.layout()))
KIN_PRINT(print(tKVpKV.layout()))
#endif
// Set predicates for k bounds
if (!Is_even_K) {
@@ -4,6 +4,8 @@
#pragma once
#include "debug.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
@@ -298,16 +300,17 @@ template <typename Kernel_traits>
__forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;
// base col of thread's slice relative to the block
const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
// base row of thread's slice relative to the block
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
// base row of thread's slice relative to the entire tensor
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
// row of thread's slice relative to its page
const int page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size;
KIN_PRINT(printf("%d", virtual_page_idx))
return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride
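// Host-side sanity sketch of the paged-KV offset math in
// init_thread_kv_page_slice_offset, with assumed trait values (the tail of
// the return statement is cut off in this hunk and is assumed to end with
// "+ col_offset"):
#include <cstdio>

constexpr int kGmemThreadsPerRow = 8;   // assumed, not from the diff
constexpr int kGmemRowsPerThread = 4;   // assumed, not from the diff
constexpr int kGmemElemsPerLoad  = 8;   // assumed, not from the diff
constexpr int kBlockN            = 32;  // assumed, not from the diff

int thread_kv_page_slice_offset(int tidx, int n_block_max, int page_block_size,
                                const int* block_table, int page_stride, int row_stride) {
    const int col_offset        = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
    const int block_row_offset  = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
    const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
    const int page_offset       = global_row_offset % page_block_size;
    const int virtual_page_idx  = global_row_offset / page_block_size;
    return block_table[virtual_page_idx] * page_stride
         + page_offset * row_stride
         + col_offset;
}

int main() {
    const int block_table[] = {7, 3, 11, 5};  // arbitrary physical page ids
    // Thread 0 of the last n-block: global row 32 lands in logical page 2
    // (32 / 16), row 0 within it, so the offset is 11 * 2048 = 22528.
    printf("%d\n", thread_kv_page_slice_offset(/*tidx=*/0, /*n_block_max=*/2,
                                               /*page_block_size=*/16, block_table,
                                               /*page_stride=*/2048, /*row_stride=*/128));
    return 0;
}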