Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14356d5d
Commit
14356d5d
authored
Mar 31, 2025
by
zhuwenwen
Browse files
update list and type
parent
52675626
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
11 deletions
+11
-11
vllm/_custom_ops.py
vllm/_custom_ops.py
+5
-5
vllm/sequence.py
vllm/sequence.py
+6
-6
No files found.
vllm/_custom_ops.py
View file @
14356d5d
...
@@ -956,7 +956,7 @@ def triton_int8_gemm_helper(m: int,
...
@@ -956,7 +956,7 @@ def triton_int8_gemm_helper(m: int,
per_token_act_quant
:
bool
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
use_bias
:
bool
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
float16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
float16
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
best_config
:
Optional
[
list
]
=
None
):
best_config
:
Optional
[
list
]
=
None
):
return
quant_tools
.
triton_int8_gemm_helper
(
m
,
n
,
k
,
per_token_act_quant
,
per_out_channel_weight_quant
,
use_bias
,
out_dtype
,
device
,
best_config
)
return
quant_tools
.
triton_int8_gemm_helper
(
m
,
n
,
k
,
per_token_act_quant
,
per_out_channel_weight_quant
,
use_bias
,
out_dtype
,
device
,
best_config
)
...
@@ -1749,8 +1749,8 @@ def free_shared_buffer(ptr: int) -> None:
...
@@ -1749,8 +1749,8 @@ def free_shared_buffer(ptr: int) -> None:
def
read_cache
(
def
read_cache
(
keys
:
torch
.
Tensor
,
keys
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
key_caches
:
L
ist
[
torch
.
Tensor
],
key_caches
:
l
ist
[
torch
.
Tensor
],
value_caches
:
L
ist
[
torch
.
Tensor
],
value_caches
:
l
ist
[
torch
.
Tensor
],
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
kv_cache_dtype
:
str
)
->
None
:
)
->
None
:
...
@@ -1761,8 +1761,8 @@ def read_cache(
...
@@ -1761,8 +1761,8 @@ def read_cache(
def
write_cache_multi_layers
(
def
write_cache_multi_layers
(
keys
:
torch
.
Tensor
,
keys
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
key_caches
:
L
ist
[
torch
.
Tensor
],
key_caches
:
l
ist
[
torch
.
Tensor
],
value_caches
:
L
ist
[
torch
.
Tensor
],
value_caches
:
l
ist
[
torch
.
Tensor
],
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
kv_cache_dtype
:
str
)
->
None
:
)
->
None
:
...
...
vllm/sequence.py
View file @
14356d5d
...
@@ -188,7 +188,7 @@ class SequenceData(msgspec.Struct,
...
@@ -188,7 +188,7 @@ class SequenceData(msgspec.Struct,
@
staticmethod
@
staticmethod
def
from_prompt_token_counts
(
def
from_prompt_token_counts
(
*
token_counts
:
T
uple
[
int
,
int
])
->
"SequenceData"
:
*
token_counts
:
t
uple
[
int
,
int
])
->
"SequenceData"
:
"""
"""
Construct a :class:`SequenceData` instance by concatenating
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
prompt token sequences.
...
@@ -1334,9 +1334,9 @@ class Logits(msgspec.Struct, array_like=True,
...
@@ -1334,9 +1334,9 @@ class Logits(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it use used for last accepted tokens.
# all tokens, whereas for decode step, it use used for last accepted tokens.
logits
:
torch
.
Tensor
logits
:
torch
.
Tensor
# The sequence group metadata list. Only needed for decode step.
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list
:
Optional
[
L
ist
[
SequenceGroupMetadata
]]
=
None
seq_group_metadata_list
:
Optional
[
l
ist
[
SequenceGroupMetadata
]]
=
None
_seq_ids
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_seq_ids
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
seq_group_metadata_list
is
not
None
:
if
self
.
seq_group_metadata_list
is
not
None
:
...
@@ -1344,12 +1344,12 @@ class Logits(msgspec.Struct, array_like=True,
...
@@ -1344,12 +1344,12 @@ class Logits(msgspec.Struct, array_like=True,
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
@
property
@
property
def
seq_ids
(
self
)
->
L
ist
[
int
]:
def
seq_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
_seq_ids
return
self
.
_seq_ids
def
update
(
self
,
def
update
(
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]):
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]):
"""Update hidden states from target model invocation. Only used for
"""Update hidden states from target model invocation. Only used for
decode steps"""
decode steps"""
assert
len
(
seq_group_metadata_list
)
==
len
(
logits
)
assert
len
(
seq_group_metadata_list
)
==
len
(
logits
)
...
@@ -1357,7 +1357,7 @@ class Logits(msgspec.Struct, array_like=True,
...
@@ -1357,7 +1357,7 @@ class Logits(msgspec.Struct, array_like=True,
self
.
logits
=
torch
.
cat
([
self
.
logits
,
logits
])
self
.
logits
=
torch
.
cat
([
self
.
logits
,
logits
])
def
prune
(
self
,
def
prune
(
self
,
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
])
->
None
:
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
])
->
None
:
"""Prune to provided list of sequence ids. Only used for decode steps.
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
"""
# Currently this prunes all seq_ids not present in
# Currently this prunes all seq_ids not present in
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment