Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fea96436
Commit
fea96436
authored
Jan 05, 2026
by
zhuwenwen
Browse files
update indexer_k_cache_kernel
parent
1af252cb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
15 deletions
+41
-15
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+39
-13
setup.py
setup.py
+2
-2
No files found.
csrc/cache_kernels.cu
View file @
fea96436
...
...
@@ -21,6 +21,7 @@
#include <cfloat> // FLT_MIN
#include <map>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
...
...
@@ -798,7 +799,19 @@ __global__ void indexer_k_cache_kernel(
const
int64_t
dst_offset
=
block_idx
*
cache_block_size
*
cache_stride
+
block_offset
*
head_dim
+
head_dim_idx
;
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
i
++
)
{
kv_cache
[
dst_offset
+
i
]
=
static_cast
<
cache_t
>
(
k_val_ptr
[
i
]);
float
val
=
static_cast
<
float
>
(
k_val_ptr
[
i
]);
if
constexpr
(
std
::
is_same
<
cache_t
,
at
::
Half
>::
value
||
std
::
is_same
<
cache_t
,
__half
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
__float2half
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
cache_t
,
at
::
BFloat16
>::
value
||
std
::
is_same
<
cache_t
,
__nv_bfloat16
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
__float2bfloat16
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
cache_t
,
float
>::
value
)
{
kv_cache
[
dst_offset
+
i
]
=
val
;
}
else
{
kv_cache
[
dst_offset
+
i
]
=
static_cast
<
cache_t
>
(
val
);
}
}
}
}
// namespace vllm
...
...
@@ -1625,16 +1638,29 @@ void indexer_k_cache(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
k
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
k
.
dtype
(),
"indexer_k_cache"
,
[
&
]
{
using
kv_t
=
scalar_t
;
using
cache_t
=
scalar_t
;
indexer_k_cache_kernel
<
kv_t
,
cache_t
>
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
k
.
scalar_type
(),
"indexer_k_cache_k"
,
([
&
]
{
using
k_t
=
scalar_t
;
if
(
kv_cache
.
scalar_type
()
==
at
::
ScalarType
::
Float
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
kv_t
*>
(
k
.
data_ptr
(
)
),
reinterpret_cast
<
cache_t
*>
(
kv_cache
.
data_ptr
(
)
),
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
float
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
});
}
else
if
(
kv_cache
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
vllm
::
indexer_k_cache_kernel
<
k_t
,
at
::
Half
>
<<<
grid
,
block
,
0
,
stream
>>>
(
k
.
data_ptr
<
k_t
>
(),
kv_cache
.
data_ptr
<
at
::
Half
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
head_dim
,
cache_block_size
,
cache_stride
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported kv_cache dtype: "
,
kv_cache
.
dtype
());
}
}));
}
setup.py
View file @
fea96436
...
...
@@ -509,9 +509,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
!=
'Unknown'
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
version
=
'das.opt1.
alph
a.'
+
sha
[:
7
]
version
=
'das.opt1.
bet
a.'
+
sha
[:
7
]
else
:
version
=
'das.opt1.
alph
a'
version
=
'das.opt1.
bet
a'
# dtk version
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment