Commit d068b568 authored by zhuwenwen's avatar zhuwenwen
Browse files

update indexer_k_cache

parent e03b1b33
......@@ -1638,10 +1638,14 @@ void indexer_k_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
k.scalar_type(), "indexer_k_cache_k", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
k.scalar_type(), "indexer_k_cache", ([&] {
using k_t = scalar_t;
if (kv_cache.scalar_type() == at::ScalarType::Float) {
auto kv_cache_type = kv_cache.scalar_type();
if (kv_cache_type == at::ScalarType::Float) {
vllm::indexer_k_cache_kernel<k_t, float>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
......@@ -1650,7 +1654,7 @@ void indexer_k_cache(
head_dim,
cache_block_size,
cache_stride);
} else if (kv_cache.scalar_type() == at::ScalarType::Half) {
} else if (kv_cache_type == at::ScalarType::Half) {
vllm::indexer_k_cache_kernel<k_t, at::Half>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
......@@ -1659,8 +1663,17 @@ void indexer_k_cache(
head_dim,
cache_block_size,
cache_stride);
} else if (kv_cache_type == at::ScalarType::BFloat16) {
vllm::indexer_k_cache_kernel<k_t, at::BFloat16>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<at::BFloat16>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else {
TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
}
}));
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment