Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d76fc11e
Commit
d76fc11e
authored
Jan 28, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev
parents
38166ec4
58996f35
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
248 additions
and
29 deletions
+248
-29
vllm/renderers/hf.py
vllm/renderers/hf.py
+70
-0
vllm/tool_parsers/kimi_k2_tool_parser.py
vllm/tool_parsers/kimi_k2_tool_parser.py
+2
-2
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+2
-1
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/kimi_k25.py
vllm/transformers_utils/configs/kimi_k25.py
+129
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+1
-0
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+10
-3
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
+3
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+2
-1
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+0
-3
vllm/v1/worker/gpu/buffer_utils.py
vllm/v1/worker/gpu/buffer_utils.py
+20
-0
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+2
-5
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+5
-12
No files found.
vllm/renderers/hf.py
View file @
d76fc11e
...
...
@@ -20,9 +20,11 @@ from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption
,
ChatTemplateResolutionError
,
ConversationMessage
,
build_video_prompts_from_mm_data
,
load_chat_template
,
parse_chat_messages
,
parse_chat_messages_async
,
rebuild_mm_uuids_from_mm_data
,
)
from
vllm.inputs
import
TextPrompt
,
TokensPrompt
from
vllm.logger
import
init_logger
...
...
@@ -547,6 +549,40 @@ class HfRenderer(RendererLike):
**
kwargs
,
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
# model which uses unified vision chunks for both images and videos.
if
(
getattr
(
model_config
.
hf_config
,
"use_unified_vision_chunk"
,
False
)
and
mm_uuids
is
not
None
and
mm_data
is
not
None
):
mm_uuids
=
rebuild_mm_uuids_from_mm_data
(
mm_uuids
,
mm_data
)
# get video placehoder, replace it with runtime video-chunk prompts
video_placeholder
=
getattr
(
model_config
.
hf_config
,
"video_placeholder"
,
None
)
if
video_placeholder
and
isinstance
(
prompt_raw
,
str
):
video_prompts
=
build_video_prompts_from_mm_data
(
mm_data
)
# replace in order
prompt_raw_parts
=
prompt_raw
.
split
(
video_placeholder
)
if
len
(
prompt_raw_parts
)
==
len
(
video_prompts
)
+
1
:
prompt_raw
=
""
.
join
(
[
prompt_raw_parts
[
i
]
+
video_prompts
[
i
]
for
i
in
range
(
len
(
video_prompts
))
]
)
prompt_raw
+=
prompt_raw_parts
[
-
1
]
else
:
logger
.
warning
(
"Number of video placeholders (%d) does not match "
"number of videos (%d) in the request."
,
len
(
prompt_raw_parts
)
-
1
,
len
(
video_prompts
),
)
prompt
=
(
TextPrompt
(
prompt
=
prompt_raw
)
if
isinstance
(
prompt_raw
,
str
)
...
...
@@ -587,6 +623,40 @@ class HfRenderer(RendererLike):
**
kwargs
,
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
# model which uses unified vision chunks for both images and videos.
if
(
getattr
(
model_config
.
hf_config
,
"use_unified_vision_chunk"
,
False
)
and
mm_uuids
is
not
None
and
mm_data
is
not
None
):
mm_uuids
=
rebuild_mm_uuids_from_mm_data
(
mm_uuids
,
mm_data
)
# get video placehoder, replace it with runtime video-chunk prompts
video_placeholder
=
getattr
(
model_config
.
hf_config
,
"video_placeholder"
,
None
)
if
video_placeholder
and
isinstance
(
prompt_raw
,
str
):
video_prompts
=
build_video_prompts_from_mm_data
(
mm_data
)
# replace in order
prompt_raw_parts
=
prompt_raw
.
split
(
video_placeholder
)
if
len
(
prompt_raw_parts
)
==
len
(
video_prompts
)
+
1
:
prompt_raw
=
""
.
join
(
[
prompt_raw_parts
[
i
]
+
video_prompts
[
i
]
for
i
in
range
(
len
(
video_prompts
))
]
)
prompt_raw
+=
prompt_raw_parts
[
-
1
]
else
:
logger
.
warning
(
"Number of video placeholders (%d) does not match "
"number of videos (%d) in the request."
,
len
(
prompt_raw_parts
)
-
1
,
len
(
video_prompts
),
)
prompt
=
(
TextPrompt
(
prompt
=
prompt_raw
)
if
isinstance
(
prompt_raw
,
str
)
...
...
vllm/tool_parsers/kimi_k2_tool_parser.py
View file @
d76fc11e
...
...
@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser):
if
current_tool_call_matches
:
tool_id
,
tool_args
=
current_tool_call_matches
.
groups
()
tool_name
=
tool_id
.
split
(
":"
)[
0
].
split
(
"."
)[
-
1
]
current_tool_call
[
"id"
]
=
tool_id
current_tool_call
[
"id"
]
=
tool_id
.
strip
()
current_tool_call
[
"name"
]
=
tool_name
current_tool_call
[
"arguments"
]
=
tool_args
else
:
...
...
@@ -458,7 +458,7 @@ class KimiK2ToolParser(ToolParser):
if
current_tool_call_name_matches
:
(
tool_id_str
,)
=
current_tool_call_name_matches
.
groups
()
tool_name
=
tool_id_str
.
split
(
":"
)[
0
].
split
(
"."
)[
-
1
]
current_tool_call
[
"id"
]
=
tool_id_str
current_tool_call
[
"id"
]
=
tool_id_str
.
strip
()
current_tool_call
[
"name"
]
=
tool_name
current_tool_call
[
"arguments"
]
=
""
else
:
...
...
vllm/transformers_utils/config.py
View file @
d76fc11e
...
...
@@ -81,6 +81,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
isaac
=
"IsaacConfig"
,
kimi_linear
=
"KimiLinearConfig"
,
kimi_vl
=
"KimiVLConfig"
,
kimi_k25
=
"KimiK25Config"
,
RefinedWeb
=
"RWConfig"
,
# For tiiuae/falcon-40b(-instruct)
RefinedWebModel
=
"RWConfig"
,
# For tiiuae/falcon-7b(-instruct)
jais
=
"JAISConfig"
,
...
...
@@ -328,7 +329,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
partial_rotary_factor
=
getattr_iter
(
config
,
names
,
None
,
warn
=
True
)
ompe
=
getattr
(
config
,
"original_max_position_embeddings"
,
None
)
if
Version
(
version
(
"transformers"
))
<
Version
(
"5.0.0
.dev0
"
):
if
Version
(
version
(
"transformers"
))
<
Version
(
"5.0.0"
):
# Transformers v4 installed, legacy config fields may be present
if
(
rope_scaling
:
=
getattr
(
config
,
"rope_scaling"
,
None
))
is
not
None
:
config
.
rope_parameters
=
rope_scaling
...
...
vllm/transformers_utils/configs/__init__.py
View file @
d76fc11e
...
...
@@ -39,6 +39,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"MoonViTConfig"
:
"vllm.transformers_utils.configs.moonvit"
,
"KimiLinearConfig"
:
"vllm.transformers_utils.configs.kimi_linear"
,
"KimiVLConfig"
:
"vllm.transformers_utils.configs.kimi_vl"
,
"KimiK25Config"
:
"vllm.transformers_utils.configs.kimi_k25"
,
"NemotronConfig"
:
"vllm.transformers_utils.configs.nemotron"
,
"NemotronHConfig"
:
"vllm.transformers_utils.configs.nemotron_h"
,
"Olmo3Config"
:
"vllm.transformers_utils.configs.olmo3"
,
...
...
@@ -78,6 +79,7 @@ __all__ = [
"MoonViTConfig"
,
"KimiLinearConfig"
,
"KimiVLConfig"
,
"KimiK25Config"
,
"NemotronConfig"
,
"NemotronHConfig"
,
"Olmo3Config"
,
...
...
vllm/transformers_utils/configs/kimi_k25.py
0 → 100644
View file @
d76fc11e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2.5 Model Configuration.
This configuration supports video-chunk as an internal modality type.
A video-chunk is the smallest independently processable unit of video.
"""
from
transformers
import
DeepseekV3Config
from
transformers.configuration_utils
import
PretrainedConfig
class
KimiK25VisionConfig
(
PretrainedConfig
):
model_type
=
"kimi_k25_vision"
def
__init__
(
self
,
# Vision Tower
patch_size
:
int
=
14
,
init_pos_emb_height
:
int
=
64
,
init_pos_emb_width
:
int
=
64
,
init_pos_emb_time
:
int
=
4
,
pos_emb_type
:
str
=
"divided_fixed"
,
num_attention_heads
:
int
=
16
,
num_hidden_layers
:
int
=
27
,
hidden_size
:
int
=
1152
,
intermediate_size
:
int
=
4304
,
merge_kernel_size
:
tuple
[
int
,
int
]
=
(
2
,
2
),
video_attn_type
:
str
=
"spatial_temporal"
,
merge_type
:
str
=
"sd2_tpool"
,
# MM Projector
mm_projector_type
:
str
=
"patchmerger"
,
mm_hidden_size
:
int
|
None
=
None
,
projector_hidden_act
:
str
=
"gelu"
,
projector_ln_eps
:
float
=
1e-5
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
# Vision Tower
self
.
patch_size
=
patch_size
self
.
init_pos_emb_height
=
init_pos_emb_height
self
.
init_pos_emb_width
=
init_pos_emb_width
self
.
init_pos_emb_time
=
init_pos_emb_time
self
.
pos_emb_type
=
pos_emb_type
self
.
num_attention_heads
=
num_attention_heads
self
.
num_hidden_layers
=
num_hidden_layers
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
merge_kernel_size
=
merge_kernel_size
self
.
video_attn_type
=
video_attn_type
self
.
merge_type
=
merge_type
# MM Projector
self
.
mm_projector_type
=
mm_projector_type
if
mm_hidden_size
is
not
None
:
self
.
mm_hidden_size
=
mm_hidden_size
else
:
self
.
mm_hidden_size
=
hidden_size
self
.
projector_hidden_act
=
projector_hidden_act
self
.
projector_ln_eps
=
projector_ln_eps
class
KimiK25Config
(
PretrainedConfig
):
"""Kimi-K2.5 model configuration.
Kimi-K2.5 extends Kimi-K2 with vision support using video-chunks.
A video-chunk consists of multiple consecutive frames
that are processed together with temporal pooling.
Args:
vision_config: Configuration for the vision tower and projector.
text_config: Configuration for the text model (DeepseekV3).
ignore_index: The ignore index for the loss function.
media_placeholder_token_id: The token ID for media placeholders.
pad_token_id: The token ID for padding.
"""
model_type
=
"kimi_k25"
def
__init__
(
self
,
vision_config
:
dict
|
KimiK25VisionConfig
|
None
=
None
,
text_config
:
dict
|
DeepseekV3Config
|
None
=
None
,
ignore_index
:
int
=
-
100
,
media_placeholder_token_id
:
int
=
163605
,
pad_token_id
:
int
=
0
,
use_unified_vision_chunk
:
bool
=
False
,
video_placeholder
:
str
=
"<|kimi_k25_video_placeholder|>"
,
**
kwargs
,
):
# Vision config
if
vision_config
is
None
:
vision_config
=
KimiK25VisionConfig
()
elif
isinstance
(
vision_config
,
dict
):
vision_config
=
KimiK25VisionConfig
(
**
vision_config
)
self
.
vision_config
:
KimiK25VisionConfig
=
vision_config
# Text config
if
text_config
is
None
:
text_config
=
DeepseekV3Config
()
elif
isinstance
(
text_config
,
dict
):
text_config
=
DeepseekV3Config
(
**
text_config
)
self
.
text_config
:
DeepseekV3Config
=
text_config
# Set mm_hidden_size to text hidden size if not explicitly set
if
self
.
vision_config
.
mm_hidden_size
==
self
.
vision_config
.
hidden_size
:
self
.
vision_config
.
mm_hidden_size
=
self
.
text_config
.
hidden_size
# Other config
self
.
ignore_index
=
ignore_index
self
.
media_placeholder_token_id
=
media_placeholder_token_id
self
.
use_unified_vision_chunk
=
use_unified_vision_chunk
self
.
video_placeholder
=
video_placeholder
# Propagate quantization config from text model
if
getattr
(
self
.
text_config
,
"quantization_config"
,
None
)
is
not
None
:
self
.
quantization_config
=
self
.
text_config
.
quantization_config
super
().
__init__
(
pad_token_id
=
pad_token_id
,
**
kwargs
)
@
property
def
hidden_size
(
self
)
->
int
:
"""Get hidden size from text config for compatibility."""
return
self
.
text_config
.
hidden_size
@
property
def
vocab_size
(
self
)
->
int
:
"""Get vocab size from text config for compatibility."""
return
self
.
text_config
.
vocab_size
vllm/transformers_utils/model_arch_config_convertor.py
View file @
d76fc11e
...
...
@@ -398,6 +398,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"qwen3_next_mtp"
:
Qwen3NextMTPModelArchConfigConvertor
,
"mimo_mtp"
:
MimoMTPModelArchConfigConvertor
,
"glm4_moe_mtp"
:
GLM4MoeMTPModelArchConfigConvertor
,
"glm_ocr_mtp"
:
GLM4MoeMTPModelArchConfigConvertor
,
"ernie_mtp"
:
ErnieMTPModelArchConfigConvertor
,
"pangu_ultra_moe_mtp"
:
PanguUltraMoeMTPModelArchConfigConvertor
,
"longcat_flash_mtp"
:
LongCatFlashMTPModelArchConfigConvertor
,
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
d76fc11e
...
...
@@ -330,7 +330,14 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
kv_sharing_target_layer_name
is
None
:
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
(
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
...
...
@@ -382,8 +389,8 @@ class RocmAttentionImpl(AttentionImpl):
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
]
if
key
is
not
None
else
None
,
value
=
value
[:
num_actual_tokens
]
if
value
is
not
None
else
None
,
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
key_cache
=
key_cache
,
...
...
vllm/v1/attention/ops/chunked_prefill_paged_decode.py
View file @
d76fc11e
...
...
@@ -302,8 +302,9 @@ def chunked_prefill_paged_decode(
block_size
=
value_cache
.
shape
[
3
]
num_seqs
=
len
(
seq_lens
)
num_query_heads
=
query
.
shape
[
1
]
num_kv_heads
=
key
.
shape
[
1
]
num_queries_per_kv
=
query
.
shape
[
1
]
//
key
.
shape
[
1
]
# key may be None in cross-attention decode (already cached from encoder)
num_kv_heads
=
key
.
shape
[
1
]
if
key
is
not
None
else
key_cache
.
shape
[
1
]
num_queries_per_kv
=
num_query_heads
//
num_kv_heads
head_size
=
query
.
shape
[
2
]
# Conversion of FP8 Tensor from uint8 storage to
...
...
vllm/v1/spec_decode/eagle.py
View file @
d76fc11e
...
...
@@ -405,7 +405,7 @@ class SpecDecodeBaseProposer:
return
draft_token_ids
.
view
(
-
1
,
1
)
if
self
.
uses_mrope
:
positions
=
self
.
positions
[:,
last_token_indices
]
positions
=
self
.
mrope_
positions
[:,
last_token_indices
]
else
:
positions
=
self
.
positions
[
last_token_indices
]
if
self
.
method
in
(
...
...
@@ -1128,6 +1128,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration"
,
"Qwen3VLForConditionalGeneration"
,
"Qwen3VLMoeForConditionalGeneration"
,
"GlmOcrForConditionalGeneration"
,
]:
self
.
model
.
config
.
image_token_index
=
target_model
.
config
.
image_token_id
elif
self
.
get_model_name
(
target_model
)
==
"PixtralForConditionalGeneration"
:
...
...
vllm/v1/structured_output/__init__.py
View file @
d76fc11e
...
...
@@ -74,9 +74,6 @@ class StructuredOutputManager:
self
.
tokenizer
=
cached_tokenizer_from_config
(
model_config
=
self
.
vllm_config
.
model_config
)
reasoning_parser
=
(
self
.
vllm_config
.
structured_outputs_config
.
reasoning_parser
)
reasoning_parser_plugin
=
(
self
.
vllm_config
.
structured_outputs_config
.
reasoning_parser_plugin
)
...
...
vllm/v1/worker/gpu/buffer_utils.py
View file @
d76fc11e
...
...
@@ -11,6 +11,26 @@ from vllm.utils.platform_utils import is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
def
async_copy_to_gpu
(
x
:
torch
.
Tensor
|
np
.
ndarray
,
out
:
torch
.
Tensor
|
None
=
None
,
device
:
torch
.
device
|
None
=
None
,
)
->
torch
.
Tensor
:
if
isinstance
(
x
,
np
.
ndarray
):
x
=
torch
.
from_numpy
(
x
)
assert
x
.
is_cpu
assert
not
x
.
is_pinned
()
if
out
is
None
:
assert
device
is
not
None
out
=
torch
.
empty_like
(
x
,
device
=
device
)
# CPU-to-CPU copy
tmp
=
x
.
pin_memory
()
# CPU-to-GPU copy
return
out
.
copy_
(
tmp
,
non_blocking
=
True
)
class
UvaBuffer
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
):
if
not
is_uva_available
():
...
...
vllm/v1/worker/gpu/mm/encoder_runner.py
View file @
d76fc11e
...
...
@@ -6,7 +6,6 @@ import torch
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItem
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.utils
import
sanity_check_mm_encoder_outputs
...
...
@@ -32,8 +31,6 @@ class EncoderRunner:
self
.
req_id_to_mm_features
:
dict
[
str
,
list
[
MultiModalFeatureSpec
]]
=
{}
self
.
encoder_cache
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
tmp_is_mm_embed
=
UvaBufferPool
(
max_num_tokens
,
torch
.
bool
)
def
add_request
(
self
,
req_id
:
str
,
mm_features
:
list
[
MultiModalFeatureSpec
]):
self
.
req_id_to_mm_features
[
req_id
]
=
mm_features
...
...
@@ -114,7 +111,7 @@ class EncoderRunner:
total_num_scheduled_tokens
,
dtype
=
torch
.
bool
,
device
=
"cpu"
,
pin_memory
=
Fals
e
,
pin_memory
=
Tru
e
,
)
for
i
,
req_id
in
enumerate
(
req_ids
):
if
not
is_prefilling
[
i
]:
...
...
@@ -163,7 +160,7 @@ class EncoderRunner:
mm_embeds
.
append
(
mm_embeds_item
)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed
=
self
.
tmp_
is_mm_embed
.
copy_to_gpu
(
is_mm_embed
)
is_mm_embed
=
is_mm_embed
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
return
mm_embeds
,
is_mm_embed
@
torch
.
inference_mode
()
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
d76fc11e
...
...
@@ -30,7 +30,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_kv_cache
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.gpu.buffer_utils
import
async_copy_to_gpu
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.dp_utils
import
(
get_cudagraph_and_dp_padding
,
...
...
@@ -172,11 +172,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# LoRA-related workers.
self
.
lora_state
=
LoraState
(
max_num_reqs
=
self
.
max_num_reqs
)
# Buffers for CPU-to-GPU copies.
self
.
tmp_idx_mapping
=
UvaBufferPool
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
tmp_cu_num_logits
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
tmp_query_start_loc
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
...
...
@@ -518,7 +513,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
req_states
.
req_id_to_index
[
req_id
]
for
req_id
in
req_ids
]
idx_mapping_np
=
np
.
array
(
idx_mapping_list
,
dtype
=
np
.
int32
)
idx_mapping
=
self
.
tmp_idx_mapping
.
copy_to_gpu
(
idx_mapping_np
)
idx_mapping
=
async_
copy_to_gpu
(
idx_mapping_np
,
device
=
self
.
device
)
# Get the number of draft tokens for each request.
if
not
scheduler_output
.
scheduled_spec_decode_tokens
:
...
...
@@ -546,7 +541,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_logits_np
=
np
.
empty
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
cu_num_logits_np
[
0
]
=
0
np
.
cumsum
(
num_logits
,
out
=
cu_num_logits_np
[
1
:])
cu_num_logits
=
self
.
tmp_cu_num_logits
.
copy_to_gpu
(
cu_num_logits_np
)
cu_num_logits
=
async_
copy_to_gpu
(
cu_num_logits_np
,
device
=
self
.
device
)
expanded_idx_mapping
=
expand_idx_mapping
(
idx_mapping
,
...
...
@@ -565,10 +560,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
self
.
tmp_query_start_loc
.
copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
,
)
async_copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
...
...
Prev
1
…
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment