Commit fea96436 authored by zhuwenwen's avatar zhuwenwen
Browse files

update indexer_k_cache_kernel

parent 1af252cb
......@@ -21,6 +21,7 @@
#include <cfloat> // FLT_MIN
#include <map>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
......@@ -798,7 +799,19 @@ __global__ void indexer_k_cache_kernel(
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] = static_cast<cache_t>(k_val_ptr[i]);
float val = static_cast<float>(k_val_ptr[i]);
if constexpr (std::is_same<cache_t, at::Half>::value ||
std::is_same<cache_t, __half>::value) {
kv_cache[dst_offset + i] = __float2half(val);
} else if constexpr (std::is_same<cache_t, at::BFloat16>::value ||
std::is_same<cache_t, __nv_bfloat16>::value) {
kv_cache[dst_offset + i] = __float2bfloat16(val);
} else if constexpr (std::is_same<cache_t, float>::value) {
kv_cache[dst_offset + i] = val;
} else {
kv_cache[dst_offset + i] = static_cast<cache_t>(val);
}
}
}
} // namespace vllm
......@@ -1625,16 +1638,29 @@ void indexer_k_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(k.dtype(), "indexer_k_cache", [&] {
using kv_t = scalar_t;
using cache_t = scalar_t;
indexer_k_cache_kernel<kv_t, cache_t>
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
k.scalar_type(), "indexer_k_cache_k", ([&] {
using k_t = scalar_t;
if (kv_cache.scalar_type() == at::ScalarType::Float) {
vllm::indexer_k_cache_kernel<k_t, float>
<<<grid, block, 0, stream>>>(
reinterpret_cast<kv_t*>(k.data_ptr()),
reinterpret_cast<cache_t*>(kv_cache.data_ptr()),
k.data_ptr<k_t>(),
kv_cache.data_ptr<float>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
});
} else if (kv_cache.scalar_type() == at::ScalarType::Half) {
vllm::indexer_k_cache_kernel<k_t, at::Half>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<at::Half>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else {
TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
}
}));
}
......@@ -509,9 +509,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das.opt1.alpha.' + sha[:7]
version = 'das.opt1.beta.' + sha[:7]
else:
version = 'das.opt1.alpha'
version = 'das.opt1.beta'
# dtk version
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment