Commit 1af252cb authored by zhuwenwen's avatar zhuwenwen
Browse files

add bw gpt-oss-20b-BF16 tp1&2 moe (tn) configs

add indexer_k_cache_kernel
parent 90d4a822
...@@ -72,6 +72,13 @@ void indexer_k_quant_and_cache( ...@@ -72,6 +72,13 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size int64_t quant_block_size, // quantization block size
const std::string& scale_fmt); const std::string& scale_fmt);
// Indexer K cache function
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt);
void read_cache( void read_cache(
torch::Tensor& keys, torch::Tensor& keys,
torch::Tensor& values, torch::Tensor& values,
......
...@@ -768,6 +768,39 @@ __global__ void indexer_k_quant_and_cache_kernel( ...@@ -768,6 +768,39 @@ __global__ void indexer_k_quant_and_cache_kernel(
} }
} }
template <typename scalar_t, typename cache_t>
__global__ void indexer_k_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim, // dimension of each head
const int cache_block_size, // cache block size
const int cache_stride // stride for each token in kv_cache
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
}
float2 k_val = (reinterpret_cast<const float2*>(
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] = static_cast<cache_t>(k_val_ptr[i]);
}
}
} // namespace vllm } // namespace vllm
// KV_T is the data type of key and value tensors. // KV_T is the data type of key and value tensors.
...@@ -1560,3 +1593,48 @@ void indexer_k_quant_and_cache( ...@@ -1560,3 +1593,48 @@ void indexer_k_quant_and_cache(
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
CALL_INDEXER_K_QUANT_AND_CACHE); CALL_INDEXER_K_QUANT_AND_CACHE);
} }
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_CACHE(KV_T, CACHE_T) \
vllm::indexer_k_cache_kernel<KV_T, CACHE_T> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, \
cache_block_size, cache_stride);
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt) {
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + vec_size - 1) / vec_size);
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(k.dtype(), "indexer_k_cache", [&] {
using kv_t = scalar_t;
using cache_t = scalar_t;
indexer_k_cache_kernel<kv_t, cache_t>
<<<grid, block, 0, stream>>>(
reinterpret_cast<kv_t*>(k.data_ptr()),
reinterpret_cast<cache_t*>(kv_cache.data_ptr()),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
});
}
...@@ -926,6 +926,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -926,6 +926,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"int quant_block_size, str kv_cache_dtype) -> ()"); "int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache); &indexer_k_quant_and_cache);
cache_ops.def(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_cache", torch::kCUDA,
&indexer_k_cache);
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
......
...@@ -2176,6 +2176,12 @@ def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor, ...@@ -2176,6 +2176,12 @@ def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
quant_block_size, quant_block_size,
kv_cache_dtype) kv_cache_dtype)
def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str) -> None:
torch.ops._C_cache_ops.indexer_k_cache(k, kv_cache, slot_mapping,
kv_cache_dtype)
def get_device_attribute(attribute: int, device: int) -> int: def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
......
{
"triton_version": "3.1.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}
{
"triton_version": "3.1.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
...@@ -610,6 +610,13 @@ def sparse_attn_indexer( ...@@ -610,6 +610,13 @@ def sparse_attn_indexer(
quant_block_size, quant_block_size,
scale_fmt, scale_fmt,
) )
else:
ops.indexer_k_cache(
k,
kv_cache,
slot_mapping,
scale_fmt,
)
topk_indices_buffer[:hidden_states.shape[0]] = -1 topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment