Commit fea96436 authored by zhuwenwen's avatar zhuwenwen
Browse files

update indexer_k_cache_kernel

parent 1af252cb
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <cfloat> // FLT_MIN #include <cfloat> // FLT_MIN
#include <map> #include <map>
#include <vector> #include <vector>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -798,7 +799,19 @@ __global__ void indexer_k_cache_kernel( ...@@ -798,7 +799,19 @@ __global__ void indexer_k_cache_kernel(
const int64_t dst_offset = block_idx * cache_block_size * cache_stride + const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx; block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) { for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] = static_cast<cache_t>(k_val_ptr[i]); float val = static_cast<float>(k_val_ptr[i]);
if constexpr (std::is_same<cache_t, at::Half>::value ||
std::is_same<cache_t, __half>::value) {
kv_cache[dst_offset + i] = __float2half(val);
} else if constexpr (std::is_same<cache_t, at::BFloat16>::value ||
std::is_same<cache_t, __nv_bfloat16>::value) {
kv_cache[dst_offset + i] = __float2bfloat16(val);
} else if constexpr (std::is_same<cache_t, float>::value) {
kv_cache[dst_offset + i] = val;
} else {
kv_cache[dst_offset + i] = static_cast<cache_t>(val);
}
} }
} }
} // namespace vllm } // namespace vllm
...@@ -1625,16 +1638,29 @@ void indexer_k_cache( ...@@ -1625,16 +1638,29 @@ void indexer_k_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(k.dtype(), "indexer_k_cache", [&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(
using kv_t = scalar_t; k.scalar_type(), "indexer_k_cache_k", ([&] {
using cache_t = scalar_t; using k_t = scalar_t;
indexer_k_cache_kernel<kv_t, cache_t> if (kv_cache.scalar_type() == at::ScalarType::Float) {
vllm::indexer_k_cache_kernel<k_t, float>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
reinterpret_cast<kv_t*>(k.data_ptr()), k.data_ptr<k_t>(),
reinterpret_cast<cache_t*>(kv_cache.data_ptr()), kv_cache.data_ptr<float>(),
slot_mapping.data_ptr<int64_t>(), slot_mapping.data_ptr<int64_t>(),
head_dim, head_dim,
cache_block_size, cache_block_size,
cache_stride); cache_stride);
}); } else if (kv_cache.scalar_type() == at::ScalarType::Half) {
vllm::indexer_k_cache_kernel<k_t, at::Half>
<<<grid, block, 0, stream>>>(
k.data_ptr<k_t>(),
kv_cache.data_ptr<at::Half>(),
slot_mapping.data_ptr<int64_t>(),
head_dim,
cache_block_size,
cache_stride);
} else {
TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
}
}));
} }
...@@ -509,9 +509,9 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -509,9 +509,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown': if sha != 'Unknown':
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
version = 'das.opt1.alpha.' + sha[:7] version = 'das.opt1.beta.' + sha[:7]
else: else:
version = 'das.opt1.alpha' version = 'das.opt1.beta'
# dtk version # dtk version
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment