Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
899a2db4
Commit
899a2db4
authored
Feb 05, 2026
by
zhuwenwen
Browse files
sync v0.15.1(ex fused_moe&models)
parent
78c1f9e5
Changes
72
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
27 additions
and
50 deletions
+27
-50
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+0
-1
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+3
-10
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
+2
-3
vllm/v1/attention/ops/triton_decode_attention.py
vllm/v1/attention/ops/triton_decode_attention.py
+2
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+0
-1
vllm/v1/engine/detokenizer.py
vllm/v1/engine/detokenizer.py
+1
-1
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+0
-2
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+3
-0
vllm/v1/worker/gpu/buffer_utils.py
vllm/v1/worker/gpu/buffer_utils.py
+0
-20
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+4
-2
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+12
-4
vllm/v1/worker/worker_base.py
vllm/v1/worker/worker_base.py
+0
-6
No files found.
vllm/v1/attention/backends/mla/flashmla.py
View file @
899a2db4
...
...
@@ -43,7 +43,6 @@ from vllm.v1.attention.ops.flashmla import (
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm
import
envs
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
899a2db4
...
...
@@ -330,14 +330,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
(
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
...
...
@@ -389,8 +382,8 @@ class RocmAttentionImpl(AttentionImpl):
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
]
if
key
is
not
None
else
None
,
value
=
value
[:
num_actual_tokens
]
if
value
is
not
None
else
None
,
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
key_cache
=
key_cache
,
...
...
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
View file @
899a2db4
...
...
@@ -302,9 +302,8 @@ def chunked_prefill_paged_decode(
block_size
=
value_cache
.
shape
[
3
]
num_seqs
=
len
(
seq_lens
)
num_query_heads
=
query
.
shape
[
1
]
# key may be None in cross-attention decode (already cached from encoder)
num_kv_heads
=
key
.
shape
[
1
]
if
key
is
not
None
else
key_cache
.
shape
[
1
]
num_queries_per_kv
=
num_query_heads
//
num_kv_heads
num_kv_heads
=
key
.
shape
[
1
]
num_queries_per_kv
=
query
.
shape
[
1
]
//
key
.
shape
[
1
]
head_size
=
query
.
shape
[
2
]
# Conversion of FP8 Tensor from uint8 storage to
...
...
vllm/v1/attention/ops/triton_decode_attention.py
View file @
899a2db4
...
...
@@ -243,6 +243,7 @@ def _decode_att_m_fwd(
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
2
,
Lk
=
Lk
,
Lv
=
Lv
,
)
...
...
@@ -594,6 +595,7 @@ def _decode_softmax_reducev_fwd(
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
num_warps
=
4
,
num_stages
=
2
,
**
extra_kargs
,
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
899a2db4
...
...
@@ -540,7 +540,6 @@ class Scheduler(SchedulerInterface):
break
request
=
self
.
waiting
.
peek_request
()
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
...
...
vllm/v1/engine/detokenizer.py
View file @
899a2db4
...
...
@@ -256,7 +256,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
class
SlowIncrementalDetokenizer
(
BaseIncrementalDetokenizer
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
request
:
EngineCoreRequest
,
mode
=
"auto"
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
request
:
EngineCoreRequest
):
super
().
__init__
(
request
)
self
.
tokenizer
=
tokenizer
...
...
vllm/v1/spec_decode/eagle.py
View file @
899a2db4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
from
dataclasses
import
replace
from
importlib.util
import
find_spec
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
vllm/v1/structured_output/__init__.py
View file @
899a2db4
...
...
@@ -74,6 +74,9 @@ class StructuredOutputManager:
self
.
tokenizer
=
cached_tokenizer_from_config
(
model_config
=
self
.
vllm_config
.
model_config
)
reasoning_parser
=
(
self
.
vllm_config
.
structured_outputs_config
.
reasoning_parser
)
reasoning_parser_plugin
=
(
self
.
vllm_config
.
structured_outputs_config
.
reasoning_parser_plugin
)
...
...
vllm/v1/worker/gpu/buffer_utils.py
View file @
899a2db4
...
...
@@ -11,26 +11,6 @@ from vllm.utils.platform_utils import is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
def
async_copy_to_gpu
(
x
:
torch
.
Tensor
|
np
.
ndarray
,
out
:
torch
.
Tensor
|
None
=
None
,
device
:
torch
.
device
|
None
=
None
,
)
->
torch
.
Tensor
:
if
isinstance
(
x
,
np
.
ndarray
):
x
=
torch
.
from_numpy
(
x
)
assert
x
.
is_cpu
assert
not
x
.
is_pinned
()
if
out
is
None
:
assert
device
is
not
None
out
=
torch
.
empty_like
(
x
,
device
=
device
)
# CPU-to-CPU copy
tmp
=
x
.
pin_memory
()
# CPU-to-GPU copy
return
out
.
copy_
(
tmp
,
non_blocking
=
True
)
class
UvaBuffer
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
):
if
not
is_uva_available
():
...
...
vllm/v1/worker/gpu/mm/encoder_runner.py
View file @
899a2db4
...
...
@@ -6,6 +6,7 @@ import torch
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItem
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.utils
import
sanity_check_mm_encoder_outputs
...
...
@@ -30,6 +31,7 @@ class EncoderRunner:
)
self
.
req_id_to_mm_features
:
dict
[
str
,
list
[
MultiModalFeatureSpec
]]
=
{}
self
.
encoder_cache
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
tmp_is_mm_embed
=
UvaBufferPool
(
max_num_tokens
,
torch
.
bool
)
def
add_request
(
self
,
req_id
:
str
,
mm_features
:
list
[
MultiModalFeatureSpec
]):
self
.
req_id_to_mm_features
[
req_id
]
=
mm_features
...
...
@@ -111,7 +113,7 @@ class EncoderRunner:
total_num_scheduled_tokens
,
dtype
=
torch
.
bool
,
device
=
"cpu"
,
pin_memory
=
Tru
e
,
pin_memory
=
Fals
e
,
)
for
i
,
req_id
in
enumerate
(
req_ids
):
if
not
is_prefilling
[
i
]:
...
...
@@ -160,7 +162,7 @@ class EncoderRunner:
mm_embeds
.
append
(
mm_embeds_item
)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed
=
is_mm_embed
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
is_mm_embed
=
self
.
tmp_
is_mm_embed
.
copy_to_gpu
(
is_mm_embed
)
return
mm_embeds
,
is_mm_embed
@
torch
.
inference_mode
()
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
899a2db4
...
...
@@ -30,7 +30,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_kv_cache
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
async_copy_to_gpu
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.dp_utils
import
(
get_cudagraph_and_dp_padding
,
...
...
@@ -171,6 +171,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# LoRA-related workers.
self
.
lora_state
=
LoraState
(
max_num_reqs
=
self
.
max_num_reqs
)
# Buffers for CPU-to-GPU copies.
self
.
tmp_idx_mapping
=
UvaBufferPool
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
tmp_cu_num_logits
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
tmp_query_start_loc
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
...
...
@@ -513,7 +518,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
req_states
.
req_id_to_index
[
req_id
]
for
req_id
in
req_ids
]
idx_mapping_np
=
np
.
array
(
idx_mapping_list
,
dtype
=
np
.
int32
)
idx_mapping
=
async_
copy_to_gpu
(
idx_mapping_np
,
device
=
self
.
device
)
idx_mapping
=
self
.
tmp_idx_mapping
.
copy_to_gpu
(
idx_mapping_np
)
# Get the number of draft tokens for each request.
if
not
scheduler_output
.
scheduled_spec_decode_tokens
:
...
...
@@ -541,7 +546,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_logits_np
=
np
.
empty
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
cu_num_logits_np
[
0
]
=
0
np
.
cumsum
(
num_logits
,
out
=
cu_num_logits_np
[
1
:])
cu_num_logits
=
async_
copy_to_gpu
(
cu_num_logits_np
,
device
=
self
.
device
)
cu_num_logits
=
self
.
tmp_cu_num_logits
.
copy_to_gpu
(
cu_num_logits_np
)
expanded_idx_mapping
=
expand_idx_mapping
(
idx_mapping
,
...
...
@@ -560,7 +565,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
async_copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
)
self
.
tmp_query_start_loc
.
copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
,
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
...
...
vllm/v1/worker/worker_base.py
View file @
899a2db4
...
...
@@ -78,12 +78,6 @@ class WorkerBase:
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
if
vllm_config
is
not
None
and
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
# adjust to take into account data parallelism
# offset the rank by the data parallel rank
rank
=
self
.
parallel_config
.
data_parallel_rank
*
self
.
parallel_config
.
world_size
+
rank
self
.
local_rank
=
rank
%
torch
.
cuda
.
device_count
()
# Device and model state
self
.
device
:
torch
.
device
|
None
=
None
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment