"vscode:/vscode.git/clone" did not exist on "7df79c86ddc4ebf36de94671b454485caf6cc395"
Commit 50601bf4 authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

Use int64_t for page pointer arth

parent f80aa0fd
...@@ -296,21 +296,21 @@ void cp_async_wait() { ...@@ -296,21 +296,21 @@ void cp_async_wait() {
// assumes that the tensor has already been positioned at the correct head. // assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits> template <typename Kernel_traits>
__forceinline__ __device__ __forceinline__ __device__
int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) { const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kBlockN = Kernel_traits::kBlockN;
const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad; const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int page_offset = global_row_offset % page_block_size; const int64_t page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size; const int64_t virtual_page_idx = global_row_offset / page_block_size;
return block_table[virtual_page_idx] * page_stride return ((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride)
+ page_offset * row_stride + page_offset * ((int64_t) row_stride)
+ col_offset; + col_offset;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment