Commit d068b568 authored by zhuwenwen
Browse files

update indexer_k_cache

parent e03b1b33
...@@ -1638,10 +1638,14 @@ void indexer_k_cache( ...@@ -1638,10 +1638,14 @@ void indexer_k_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND2(
k.scalar_type(), "indexer_k_cache_k", ([&] { at::ScalarType::Half,
at::ScalarType::BFloat16,
k.scalar_type(), "indexer_k_cache", ([&] {
using k_t = scalar_t; using k_t = scalar_t;
if (kv_cache.scalar_type() == at::ScalarType::Float) { auto kv_cache_type = kv_cache.scalar_type();
if (kv_cache_type == at::ScalarType::Float) {
vllm::indexer_k_cache_kernel<k_t, float> vllm::indexer_k_cache_kernel<k_t, float>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(), k.data_ptr<k_t>(),
...@@ -1650,7 +1654,7 @@ void indexer_k_cache( ...@@ -1650,7 +1654,7 @@ void indexer_k_cache(
head_dim, head_dim,
cache_block_size, cache_block_size,
cache_stride); cache_stride);
} else if (kv_cache.scalar_type() == at::ScalarType::Half) { } else if (kv_cache_type == at::ScalarType::Half) {
vllm::indexer_k_cache_kernel<k_t, at::Half> vllm::indexer_k_cache_kernel<k_t, at::Half>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(), k.data_ptr<k_t>(),
...@@ -1659,8 +1663,17 @@ void indexer_k_cache( ...@@ -1659,8 +1663,17 @@ void indexer_k_cache(
head_dim, head_dim,
cache_block_size, cache_block_size,
cache_stride); cache_stride);
} else if (kv_cache_type == at::ScalarType::BFloat16) {
vllm::indexer_k_cache_kernel<k_t, at::BFloat16>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<at::BFloat16>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else { } else {
TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype()); TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
} }
})); }));
} }
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment