Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eefa41c1
Commit
eefa41c1
authored
Mar 24, 2026
by
zhuwenwen
Browse files
sync v0.18.0
parent
82155c76
Changes
253
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1046 additions
and
128 deletions
+1046
-128
tests/models/registry.py
tests/models/registry.py
+15
-13
tests/models/test_transformers.py
tests/models/test_transformers.py
+2
-2
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+2
-2
tests/test_access_log_filter.py
tests/test_access_log_filter.py
+371
-0
tests/v1/e2e/spec_decode/test_spec_decode.py
tests/v1/e2e/spec_decode/test_spec_decode.py
+2
-2
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+3
-3
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+15
-46
vllm/_custom_ops.py
vllm/_custom_ops.py
+26
-11
vllm/_xpu_ops.py
vllm/_xpu_ops.py
+297
-1
vllm/config/compilation.py
vllm/config/compilation.py
+2
-1
vllm/config/speculative.py
vllm/config/speculative.py
+10
-0
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+84
-3
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+44
-5
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+43
-7
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+40
-6
vllm/distributed/device_communicators/mnnvl_compat.py
vllm/distributed/device_communicators/mnnvl_compat.py
+2
-0
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+44
-8
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+4
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+15
-13
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+25
-3
No files found.
tests/models/registry.py
View file @
eefa41c1
...
...
@@ -266,7 +266,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"Exaone4ForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-4.0-32B"
),
"ExaoneMoEForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/K-EXAONE-236B-A23B"
,
min_transformers_version
=
"5.
0
.0"
"LGAI-EXAONE/K-EXAONE-236B-A23B"
,
min_transformers_version
=
"5.
1
.0"
),
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
...
...
@@ -283,11 +283,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Glm4MoeForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-4.5"
),
"Glm4MoeLiteForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-4.7-Flash"
,
min_transformers_version
=
"5.0.0.dev"
,
is_available_online
=
False
,
),
"GlmMoeDsaForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-5"
,
min_transformers_version
=
"5.0.1"
,
is_available_online
=
False
min_transformers_version
=
"5.0.0"
,
),
"GlmMoeDsaForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-5"
,
min_transformers_version
=
"5.0.1"
,
is_available_online
=
False
...
...
@@ -743,7 +739,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
"AudioFlamingo3ForConditionalGeneration"
:
_HfExamplesInfo
(
"nvidia/audio-flamingo-3-hf"
,
min_transformers_version
=
"5.0.0
.dev
"
"nvidia/audio-flamingo-3-hf"
,
min_transformers_version
=
"5.0.0"
),
"MusicFlamingoForConditionalGeneration"
:
_HfExamplesInfo
(
"nvidia/music-flamingo-2601-hf"
,
min_transformers_version
=
"5.0.0.dev"
...
...
@@ -1237,7 +1233,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel"
:
_HfExamplesInfo
(
"zai-org/GLM-4.7-Flash"
,
speculative_model
=
"zai-org/GLM-4.7-Flash"
,
min_transformers_version
=
"5.0.0"
,
),
"GlmOcrMTPModel"
:
_HfExamplesInfo
(
"zai-org/GLM-OCR"
,
speculative_model
=
"zai-org/GLM-OCR"
,
is_available_online
=
False
,
min_transformers_version
=
"5.1.0"
,
),
"LongCatFlashMTPModel"
:
_HfExamplesInfo
(
"meituan-longcat/LongCat-Flash-Chat"
,
...
...
@@ -1282,27 +1284,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS
=
{
"TransformersEmbeddingModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
,
min_transformers_version
=
"5.0.0
.dev
"
"BAAI/bge-base-en-v1.5"
,
min_transformers_version
=
"5.0.0"
),
"TransformersForSequenceClassification"
:
_HfExamplesInfo
(
"papluca/xlm-roberta-base-language-detection"
,
min_transformers_version
=
"5.0.0
.dev
"
,
min_transformers_version
=
"5.0.0"
,
),
"TransformersForCausalLM"
:
_HfExamplesInfo
(
"hmellor/Ilama-3.2-1B"
,
trust_remote_code
=
True
),
"TransformersMultiModalForCausalLM"
:
_HfExamplesInfo
(
"BAAI/Emu3-Chat-hf"
),
"TransformersMoEForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924"
,
min_transformers_version
=
"5.0.0
.dev
"
"allenai/OLMoE-1B-7B-0924"
,
min_transformers_version
=
"5.0.0"
),
"TransformersMultiModalMoEForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
min_transformers_version
=
"5.0.0
.dev
"
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
min_transformers_version
=
"5.0.0"
),
"TransformersMoEEmbeddingModel"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0
.dev
"
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0"
),
"TransformersMoEForSequenceClassification"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0
.dev
"
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0"
),
"TransformersMultiModalEmbeddingModel"
:
_HfExamplesInfo
(
"google/gemma-3-4b-it"
),
"TransformersMultiModalForSequenceClassification"
:
_HfExamplesInfo
(
...
...
tests/models/test_transformers.py
View file @
eefa41c1
...
...
@@ -76,7 +76,7 @@ def test_models(
from
packaging.version
import
Version
installed
=
Version
(
transformers
.
__version__
)
required
=
Version
(
"5.0.0
.dev
"
)
required
=
Version
(
"5.0.0"
)
if
model
==
"allenai/OLMoE-1B-7B-0924"
and
installed
<
required
:
pytest
.
skip
(
"MoE models with the Transformers modeling backend require "
...
...
@@ -237,4 +237,4 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
)
\ No newline at end of file
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
eefa41c1
...
...
@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -59,4 +59,4 @@ class MyGemma2Embedding(nn.Module):
weights
=
(
(
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
)
)
return
self
.
model
.
load_weights
(
weights
)
return
self
.
model
.
load_weights
(
weights
)
\ No newline at end of file
tests/test_access_log_filter.py
0 → 100644
View file @
eefa41c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the UvicornAccessLogFilter class.
"""
import
logging
from
vllm.logging_utils.access_log_filter
import
(
UvicornAccessLogFilter
,
create_uvicorn_log_config
,
)
class TestUvicornAccessLogFilter:
    """Exercise UvicornAccessLogFilter.filter() with synthetic access records."""

    @staticmethod
    def _access_record(method, path, status=200):
        """Build a ``uvicorn.access`` LogRecord for one HTTP request line."""
        return logging.LogRecord(
            name="uvicorn.access",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg='%s - "%s %s HTTP/%s" %d',
            args=("127.0.0.1:12345", method, path, "1.1", status),
            exc_info=None,
        )

    def test_filter_allows_all_when_no_excluded_paths(self):
        """An empty exclusion list lets every record through."""
        log_filter = UvicornAccessLogFilter(excluded_paths=[])
        entry = self._access_record("GET", "/v1/completions")
        assert log_filter.filter(entry) is True

    def test_filter_allows_all_when_excluded_paths_is_none(self):
        """``excluded_paths=None`` behaves like "exclude nothing"."""
        log_filter = UvicornAccessLogFilter(excluded_paths=None)
        entry = self._access_record("GET", "/health")
        assert log_filter.filter(entry) is True

    def test_filter_excludes_health_endpoint(self):
        """A configured /health path is dropped."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])
        entry = self._access_record("GET", "/health")
        assert log_filter.filter(entry) is False

    def test_filter_excludes_metrics_endpoint(self):
        """A configured /metrics path is dropped."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/metrics"])
        entry = self._access_record("GET", "/metrics")
        assert log_filter.filter(entry) is False

    def test_filter_allows_non_excluded_endpoints(self):
        """Paths that are not in the exclusion list are kept."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health", "/metrics"])
        entry = self._access_record("POST", "/v1/completions")
        assert log_filter.filter(entry) is True

    def test_filter_excludes_multiple_endpoints(self):
        """Every path in a multi-entry exclusion list is dropped."""
        log_filter = UvicornAccessLogFilter(
            excluded_paths=["/health", "/metrics", "/ping"]
        )
        # Checked in the same order as the exclusion list: /health, /metrics, /ping
        for path in ("/health", "/metrics", "/ping"):
            entry = self._access_record("GET", path)
            assert log_filter.filter(entry) is False

    def test_filter_with_query_parameters(self):
        """A query string after an excluded path does not defeat the filter."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])
        entry = self._access_record("GET", "/health?verbose=true")
        assert log_filter.filter(entry) is False

    def test_filter_different_http_methods(self):
        """Exclusion applies to the path irrespective of the HTTP verb."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/ping"])
        # GET first, then POST
        for verb in ("GET", "POST"):
            entry = self._access_record(verb, "/ping")
            assert log_filter.filter(entry) is False

    def test_filter_with_different_status_codes(self):
        """Exclusion applies irrespective of the response status code."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])
        for status_code in [200, 500, 503]:
            entry = self._access_record("GET", "/health", status_code)
            assert log_filter.filter(entry) is False
class TestCreateUvicornLogConfig:
    """Checks on the dict returned by create_uvicorn_log_config()."""

    def test_creates_valid_config_structure(self):
        """The config carries all top-level logging-config sections."""
        cfg = create_uvicorn_log_config(excluded_paths=["/health"])
        assert "version" in cfg
        assert cfg["version"] == 1
        for section in (
            "disable_existing_loggers",
            "formatters",
            "handlers",
            "loggers",
            "filters",
        ):
            assert section in cfg

    def test_config_includes_access_log_filter(self):
        """The filter entry names the filter class and the excluded paths."""
        paths = ["/health", "/metrics"]
        cfg = create_uvicorn_log_config(excluded_paths=["/health", "/metrics"])
        filters_section = cfg["filters"]
        assert "access_log_filter" in filters_section
        entry = filters_section["access_log_filter"]
        assert entry["()"] == UvicornAccessLogFilter
        assert entry["excluded_paths"] == paths

    def test_config_applies_filter_to_access_handler(self):
        """The access handler references the access-log filter."""
        cfg = create_uvicorn_log_config(excluded_paths=["/health"])
        handlers_section = cfg["handlers"]
        assert "access" in handlers_section
        assert "filters" in handlers_section["access"]
        assert "access_log_filter" in handlers_section["access"]["filters"]

    def test_config_with_custom_log_level(self):
        """A lower-case ``log_level`` shows up upper-cased on every uvicorn logger."""
        cfg = create_uvicorn_log_config(excluded_paths=["/health"], log_level="debug")
        for logger_name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
            assert cfg["loggers"][logger_name]["level"] == "DEBUG"

    def test_config_with_empty_excluded_paths(self):
        """An empty exclusion list is carried through unchanged."""
        cfg = create_uvicorn_log_config(excluded_paths=[])
        assert cfg["filters"]["access_log_filter"]["excluded_paths"] == []

    def test_config_with_none_excluded_paths(self):
        """``None`` is normalised to an empty exclusion list."""
        cfg = create_uvicorn_log_config(excluded_paths=None)
        assert cfg["filters"]["access_log_filter"]["excluded_paths"] == []
class TestIntegration:
    """Integration tests for the access log filter."""

    def test_filter_with_real_logger(self):
        """Test filter works with a real Python logger simulating uvicorn."""
        # The "uvicorn.access" logger is process-wide. Remember its state so
        # the test can put it back and not leak handlers/levels into other
        # tests that run in the same interpreter.
        logger = logging.getLogger("uvicorn.access")
        saved_level = logger.level
        saved_handlers = list(logger.handlers)

        # Create a custom handler that tracks messages
        logged_messages: list[str] = []

        class TrackingHandler(logging.Handler):
            def emit(self, record):
                logged_messages.append(record.getMessage())

        handler = TrackingHandler()
        handler.setLevel(logging.INFO)
        handler.addFilter(
            UvicornAccessLogFilter(excluded_paths=["/health", "/metrics"])
        )

        try:
            logger.setLevel(logging.INFO)
            # Clear any existing handlers so only the tracking handler sees records
            logger.handlers = []
            logger.addHandler(handler)

            # Log using uvicorn's format with args tuple
            # Format: '%s - "%s %s HTTP/%s" %d'
            requests = (
                ("GET", "/health"),
                ("GET", "/v1/completions"),
                ("GET", "/metrics"),
                ("POST", "/v1/chat/completions"),
            )
            for method, path in requests:
                logger.info(
                    '%s - "%s %s HTTP/%s" %d',
                    "127.0.0.1:12345",
                    method,
                    path,
                    "1.1",
                    200,
                )
        finally:
            # Restore the shared logger exactly as it was found.
            logger.handlers = saved_handlers
            logger.setLevel(saved_level)

        # Verify only non-excluded endpoints were logged
        assert len(logged_messages) == 2
        assert "/v1/completions" in logged_messages[0]
        assert "/v1/chat/completions" in logged_messages[1]

    def test_filter_allows_non_uvicorn_access_logs(self):
        """Test filter allows logs from non-uvicorn.access loggers."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])

        # Log record from a different logger name
        record = logging.LogRecord(
            name="uvicorn.error",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Some error message about /health",
            args=(),
            exc_info=None,
        )

        # Should allow because it's not from uvicorn.access
        assert log_filter.filter(record) is True

    def test_filter_handles_malformed_args(self):
        """Test filter handles log records with unexpected args format."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])

        # Log record with insufficient args
        record = logging.LogRecord(
            name="uvicorn.access",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Some message",
            args=("only", "two"),
            exc_info=None,
        )

        # Should allow because args doesn't have expected format
        assert log_filter.filter(record) is True

    def test_filter_handles_non_tuple_args(self):
        """Test filter handles log records with non-tuple args."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])

        # Log record with None args
        record = logging.LogRecord(
            name="uvicorn.access",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Some message without args",
            args=None,
            exc_info=None,
        )

        # Should allow because args is None
        assert log_filter.filter(record) is True
\ No newline at end of file
tests/v1/e2e/spec_decode/test_spec_decode.py
View file @
eefa41c1
...
...
@@ -383,7 +383,7 @@ def _run_eagle_correctness(
from
packaging.version
import
Version
installed
=
Version
(
transformers
.
__version__
)
required
=
Version
(
"5.0.0
.dev
"
)
required
=
Version
(
"5.0.0"
)
if
installed
<
required
:
pytest
.
skip
(
"Eagle3 with the Transformers modeling backend requires "
...
...
@@ -1030,4 +1030,4 @@ def compute_acceptance_len(metrics: list[Metric]) -> float:
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
if
n_drafts
==
0
:
return
1
return
1
+
(
n_accepted_toks
/
n_drafts
)
return
1
+
(
n_accepted_toks
/
n_drafts
)
\ No newline at end of file
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
eefa41c1
...
...
@@ -59,9 +59,9 @@ fi
# Build the kv-transfer-config once
if
[[
"
$KV_BUFFER_DEVICE
"
==
"cuda"
]]
;
then
KV_CONFIG
=
'{"kv_connector":"NixlConnector","kv_role":"kv_both"'
${
KV_CONFIG_HETERO_LAYOUT
}
'}'
KV_CONFIG
=
'{"kv_connector":"NixlConnector","kv_role":"kv_both"'
${
KV_CONFIG_HETERO_LAYOUT
}
${
KV_EXTRA_CONFIG
}
'}'
else
KV_CONFIG
=
"{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
,
\"
kv_buffer_device
\"
:
\"
$KV_BUFFER_DEVICE
\"
"
${
KV_CONFIG_HETERO_LAYOUT
}
"}"
KV_CONFIG
=
"{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
,
\"
kv_buffer_device
\"
:
\"
$KV_BUFFER_DEVICE
\"
"
${
KV_CONFIG_HETERO_LAYOUT
}
${
KV_EXTRA_CONFIG
}
"}"
fi
# Models to run
...
...
@@ -295,4 +295,4 @@ for model in "${MODELS[@]}"; do
run_tests_for_model
"
$model
"
done
echo
"All tests completed!"
echo
"All tests completed!"
\ No newline at end of file
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
eefa41c1
...
...
@@ -18,8 +18,12 @@ import ray
import
torch
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.config
import
KVTransferConfig
,
set_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
KVOutputAggregator
,
TpKVTopology
,
get_current_attn_backend
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
nixl_connector
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
...
...
@@ -58,6 +62,8 @@ from vllm.v1.kv_cache_interface import (
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.utils
import
AttentionGroup
from
.utils
import
(
create_request
,
...
...
@@ -1498,44 +1504,6 @@ def test_register_kv_caches(
backend_cls
=
TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
"layer2"
:
shared_tensor
,
}
# Store tensor info for validation
test_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
)
is_blocks_first
=
len
(
test_shape
)
==
5
and
test_shape
[
0
]
==
1
if
is_blocks_first
:
expected_tensor_size
=
shared_tensor
.
element_size
()
*
shared_tensor
.
numel
()
expected_base_addrs
=
[
shared_tensor
.
data_ptr
(),
unique_tensor
.
data_ptr
(),
]
expected_num_entries
=
2
else
:
expected_tensor_size
=
(
shared_tensor
[
0
].
element_size
()
*
shared_tensor
[
0
].
numel
()
)
expected_base_addrs
=
[
shared_tensor
[
0
].
data_ptr
(),
shared_tensor
[
1
].
data_ptr
(),
unique_tensor
[
0
].
data_ptr
(),
unique_tensor
[
1
].
data_ptr
(),
]
expected_num_entries
=
4
nixl_module
=
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with
(
patch
(
f
"
{
nixl_module
}
.NixlWrapper"
)
as
mock_nixl_wrapper
,
...
...
@@ -1716,14 +1684,13 @@ def test_register_kv_caches(
blocks_data
,
_
=
mock_wrapper_instance
.
get_xfer_descs
.
call_args
[
0
]
# Validate blocks_data structure and size
expected_blocks_count
=
8
assert
len
(
blocks_data
)
==
expected_blocks_count
,
(
f
"Expected
{
expected_blocks_count
}
blocks, got
{
len
(
blocks_data
)
}
"
)
num
_blocks
=
2
if
is_blocks_first
:
expected_block_len
=
expected_tensor_size
//
num_blocks
//
2
if
connector
.
prefer_cross_layer
_blocks
:
num_blocks
=
8
expected_block_len
=
expected_tensor_size
//
num_blocks
else
:
num_blocks
=
kv_cache_config
.
num_blocks
if
is_blocks_first
:
...
...
@@ -2360,7 +2327,9 @@ def test_compatibility_hash_validation(
)
)
remote_hash
=
compute_nixl_compatibility_hash
(
remote_vllm_config
,
decode_worker
.
backend_name
remote_vllm_config
,
decode_worker
.
backend_name
,
decode_worker
.
kv_topo
.
cross_layers_blocks
,
)
prefill_block_size
=
config_overrides
.
get
(
"block_size"
,
16
)
...
...
@@ -2497,4 +2466,4 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
port
=
1234
,
remote_tp_size
=
1
,
expected_engine_id
=
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
)
)
\ No newline at end of file
vllm/_custom_ops.py
View file @
eefa41c1
...
...
@@ -3044,13 +3044,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class
CPUDNNLGEMMHandler
:
def
__init__
(
self
)
->
None
:
self
.
handler
:
int
|
None
=
None
self
.
handler
_tensor
:
torch
.
Tensor
|
None
=
None
self
.
n
=
-
1
self
.
k
=
-
1
def
__del__
(
self
):
if
self
.
handler
is
not
None
:
torch
.
ops
.
_C
.
release_dnnl_matmul_handler
(
self
.
handler
)
if
self
.
handler
_tensor
is
not
None
:
torch
.
ops
.
_C
.
release_dnnl_matmul_handler
(
self
.
handler
_tensor
.
item
()
)
_supports_onednn
=
bool
(
hasattr
(
torch
.
ops
.
_C
,
"create_onednn_mm_handler"
))
...
...
@@ -3066,8 +3066,10 @@ def create_onednn_mm(
)
->
CPUDNNLGEMMHandler
:
handler
=
CPUDNNLGEMMHandler
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
handler
=
torch
.
ops
.
_C
.
create_onednn_mm_handler
(
weight
,
primitive_cache_size
# store the handler pointer in a tensor so it doesn't get inlined
handler
.
handler_tensor
=
torch
.
tensor
(
torch
.
ops
.
_C
.
create_onednn_mm_handler
(
weight
,
primitive_cache_size
),
dtype
=
torch
.
int64
,
)
return
handler
...
...
@@ -3079,7 +3081,7 @@ def onednn_mm(
)
->
torch
.
Tensor
:
output
=
torch
.
empty
((
*
x
.
shape
[
0
:
-
1
],
dnnl_handler
.
n
),
dtype
=
x
.
dtype
)
torch
.
ops
.
_C
.
onednn_mm
(
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
_tensor
)
return
output
...
...
@@ -3095,8 +3097,17 @@ def create_onednn_scaled_mm(
)
->
CPUDNNLGEMMHandler
:
handler
=
CPUDNNLGEMMHandler
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
handler
=
torch
.
ops
.
_C
.
create_onednn_scaled_mm_handler
(
weight
,
weight_scales
,
output_type
,
dynamic_quant
,
use_azp
,
primitive_cache_size
# store the handler pointer in a tensor so it doesn't get inlined
handler
.
handler_tensor
=
torch
.
tensor
(
torch
.
ops
.
_C
.
create_onednn_scaled_mm_handler
(
weight
,
weight_scales
,
output_type
,
dynamic_quant
,
use_azp
,
primitive_cache_size
,
),
dtype
=
torch
.
int64
,
)
return
handler
...
...
@@ -3149,11 +3160,15 @@ def onednn_scaled_mm(
bias
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
torch
.
ops
.
_C
.
onednn_scaled_mm
(
output
,
x
,
input_scale
,
input_zp
,
input_zp_adj
,
bias
,
dnnl_handler
.
handler
output
,
x
,
input_scale
,
input_zp
,
input_zp_adj
,
bias
,
dnnl_handler
.
handler_tensor
,
)
return
output
def
cpu_attn_get_scheduler_metadata
(
num_reqs
:
int
,
...
...
vllm/_xpu_ops.py
View file @
eefa41c1
...
...
@@ -7,6 +7,8 @@ import torch
from
vllm_xpu_kernels.flash_attn_interface
import
flash_attn_varlen_func
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -53,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return
torch
.
empty
((
M
,
N
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
_xpu_ops_deepseek_scaling_rope_impl
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
,
offsets
:
torch
.
Tensor
|
None
,
cos_sin_cache
:
torch
.
Tensor
|
None
,
rotary_dim
:
int
,
is_neox_style
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
key
is
not
None
return
torch
.
ops
.
_xpu_C
.
deepseek_scaling_rope
(
positions
,
query
,
key
,
offsets
,
cos_sin_cache
,
rotary_dim
,
is_neox_style
)
def
_xpu_ops_deepseek_scaling_rope_fake
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
,
offsets
:
torch
.
Tensor
|
None
,
cos_sin_cache
:
torch
.
Tensor
|
None
,
rotary_dim
:
int
,
is_neox_style
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
query
,
key
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
class
xpu_ops
:
@
staticmethod
def
flash_attn_varlen_func
(
...
...
@@ -105,9 +138,10 @@ class xpu_ops:
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
# noqa: F841
# In encode attention, v maybe not contiguous and current
# In encode attention,
k and
v may not be contiguous and the current
# kernel can't handle it
if
block_table
is
None
:
k
=
k
.
contiguous
()
v
=
v
.
contiguous
()
return
flash_attn_varlen_func
(
out
=
out
,
...
...
@@ -156,3 +190,265 @@ class xpu_ops:
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
)
return
None
@staticmethod
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
) -> None:
    """Quantize ``k`` group-wise and scatter the packed bytes into ``kv_cache``.

    ``k`` is flattened to rows of ``head_dim`` elements, quantized by the
    nested ``group_quant_torch`` helper, packed per row as the quantized
    bytes followed by the float32 scale bytes, and written in place into the
    rows of ``kv_cache.view(-1, kv_cache.shape[-1])`` selected by
    ``slot_mapping``.

    Args:
        k: Tensor whose last dimension is ``head_dim``.
        kv_cache: Destination cache, modified in place.
        slot_mapping: Destination row indices (flattened before use).
        quant_block_size: Group size along the last dimension of ``k``.
        scale_fmt: Power-of-two scales are used when this equals ``"ue8m0"``.
    """
    head_dim = k.shape[-1]
    k = k.view(-1, head_dim)  # [total_tokens, head_dim]

    def group_quant_torch(
        x: torch.Tensor,
        group_size: int,
        eps: float = 1e-10,
        dtype: torch.dtype | None = None,
        column_major_scales: bool = False,
        out_q: torch.Tensor | None = None,
        use_ue8m0: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Group-quantize ``x`` along its last dim; return (quantized, float32 scales)."""
        if use_ue8m0 is None:
            # Default fallback - could import is_deep_gemm_e8m0_used if needed
            use_ue8m0 = False

        if dtype is None:
            # Target dtype defaults to the platform's FP8 dtype.
            dtype = current_platform.fp8_dtype()

        # Validate inputs
        assert x.shape[-1] % group_size == 0, (
            f"Last dimension {x.shape[-1]} must be divisible by "
            f"group_size {group_size}"
        )
        assert x.stride(-1) == 1, "Input tensor groups must be contiguous"

        # Prepare output tensor
        if out_q is None:
            x_q = torch.empty_like(x, dtype=dtype)
        else:
            assert out_q.shape == x.shape
            x_q = out_q

        # Reshape input for group processing
        # Original shape: (..., last_dim)
        # Target shape: (..., num_groups, group_size)
        original_shape = x.shape
        num_groups = original_shape[-1] // group_size

        # Reshape to separate groups
        group_shape = original_shape[:-1] + (num_groups, group_size)
        x_grouped = x.view(group_shape)

        # Compute per-group absolute maximum values
        # Shape: (..., num_groups)
        abs_max = torch.amax(torch.abs(x_grouped), dim=-1, keepdim=False)
        # Floor the maximum at eps so an all-zero group does not give a zero scale.
        # NOTE(review): eps is materialized in x.dtype; for float16 inputs
        # 1e-10 presumably rounds to 0, which would defeat the floor - confirm
        # which input dtypes callers use.
        abs_max = torch.maximum(
            abs_max, torch.tensor(eps, device=x.device, dtype=x.dtype)
        )

        # Compute scales
        FP8_MAX = torch.finfo(dtype).max
        FP8_MIN = torch.finfo(dtype).min
        scale_raw = abs_max / FP8_MAX

        if use_ue8m0:
            # For UE8M0 format, scales must be powers of 2
            # (rounded up to the next power of two via ceil(log2(.)))
            scales = torch.pow(2.0, torch.ceil(torch.log2(scale_raw)))
        else:
            scales = scale_raw

        # Expand scales for broadcasting with grouped data
        # Shape: (..., num_groups, 1)
        scales_expanded = scales.unsqueeze(-1)

        # Quantize the grouped data
        x_scaled = x_grouped / scales_expanded
        x_clamped = torch.clamp(x_scaled, FP8_MIN, FP8_MAX)
        x_quantized = x_clamped.to(dtype)

        # Reshape back to original shape
        x_q.copy_(x_quantized.view(original_shape))

        # Prepare scales tensor in requested format
        if column_major_scales:
            # Column-major: (num_groups,) + batch_dims
            # Transpose the scales to put group dimension first
            scales_shape = (num_groups,) + original_shape[:-1]
            x_s = scales.permute(-1, *range(len(original_shape) - 1))
            x_s = x_s.contiguous().view(scales_shape)
        else:
            # Row-major: batch_dims + (num_groups,)
            x_s = scales.contiguous()

        # Ensure scales are float32
        return x_q, x_s.float()

    k_fp8, k_scale = group_quant_torch(
        k,
        group_size=quant_block_size,
        column_major_scales=False,
        use_ue8m0=(scale_fmt == "ue8m0"),
    )

    # Reinterpret both results as raw bytes so they can share one cache row.
    k_fp8_bytes = k_fp8.view(-1, head_dim).view(torch.uint8)
    # Each float32 scale becomes 4 bytes; one row of 4 bytes per scale value.
    # NOTE(review): the concatenation below needs the same number of rows in
    # both tensors, i.e. exactly one scale per token
    # (quant_block_size == head_dim) - confirm this is always the case.
    scale_bytes = k_scale.view(torch.uint8).view(-1, 4)
    k = torch.cat(
        [k_fp8_bytes, scale_bytes], dim=-1
    )  # [total_tokens, head_dim + 4]

    slot_mapping = slot_mapping.flatten()
    # kv_cache: [num_block, block_size, head_dim + 4]
    kv_cache.view(-1, kv_cache.shape[-1]).index_copy_(0, slot_mapping, k)
@staticmethod
def cp_gather_indexer_k_quant_cache(
    kv_cache: torch.Tensor,
    dst_k: torch.Tensor,
    dst_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seq_lens: torch.Tensor,
) -> None:
    """Gather quantized K values and their scales from the paged cache.

    Results are written in place into ``dst_k`` and ``dst_scale``; nothing
    is returned.

    Args:
        kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache
            Layout per block: [k_values, scale_values]
            - k_values: [block_size * head_dim]
            - scale_values: [block_size * head_dim * 4 / quant_block_size]
        dst_k: [num_tokens, head_dim] - output tensor for K values
        dst_scale: [num_tokens, head_dim / quant_block_size * 4]
            - output tensor for scale values
        block_table: [batch_size, num_blocks] - block table for indexing
        cu_seq_lens: [batch_size + 1] - cumulative sequence lengths

    NOTE(review): the torch write path visible earlier in this file packs
    each slot as [k | scale]; this gather follows the per-block layout
    documented above. Confirm which layout the cache really uses.
    NOTE(review): ``dst_k[:] = ...`` converts values if ``dst_k`` has a
    different dtype than ``kv_cache``; confirm callers pass byte-compatible
    buffers.
    """
    batch_size = block_table.size(0)
    num_tokens = dst_k.size(0)
    head_dim = dst_k.size(1)
    cache_block_size = kv_cache.size(1)
    quant_block_size = head_dim * 4 // dst_scale.size(1)
    scale_size = head_dim * 4 // quant_block_size

    # Do all index arithmetic in int64: searchsorted needs matching dtypes
    # and flat offsets into a large cache can exceed the int32 range.
    cu_seq_lens_i64 = cu_seq_lens.to(torch.int64)

    # 1-based global token ids so that searchsorted (left) lands on the
    # interval [cu_seq_lens[b] + 1, cu_seq_lens[b + 1]] of sequence b.
    token_indices = (
        torch.arange(num_tokens, device=dst_k.device, dtype=torch.int64) + 1
    )
    batch_indices = torch.searchsorted(cu_seq_lens_i64, token_indices) - 1
    batch_indices = torch.clamp(batch_indices, 0, batch_size - 1)

    # 0-based position of every token inside its own sequence. Both the
    # logical block index and the in-block offset must be derived from this
    # same 0-based value, otherwise the last slot of each block spills into
    # the following block.
    pos_in_seq = token_indices - cu_seq_lens_i64[batch_indices] - 1
    block_indices_in_table = pos_in_seq // cache_block_size
    inblock_offsets = pos_in_seq % cache_block_size

    physical_block_indices = block_table[
        batch_indices, block_indices_in_table
    ].to(torch.int64)

    # Flat element offset of the start of each token's physical block.
    block_stride = kv_cache.stride(0)
    kv_cache_flat = kv_cache.view(-1)
    src_block_offsets = physical_block_indices * block_stride

    # K values: token i of a block starts at i * head_dim inside the block.
    src_k_offsets = src_block_offsets + inblock_offsets * head_dim
    k_indices = src_k_offsets.unsqueeze(1) + torch.arange(
        head_dim, device=dst_k.device
    )
    dst_k[:] = kv_cache_flat[k_indices]

    # Scales are stored after all K values of the block, i.e. after
    # cache_block_size * head_dim elements (not after a single token's K).
    src_scale_offsets = (
        src_block_offsets
        + cache_block_size * head_dim
        + inblock_offsets * scale_size
    )
    scale_indices = src_scale_offsets.unsqueeze(1) + torch.arange(
        scale_size, device=dst_scale.device
    )
    dst_scale[:] = kv_cache_flat[scale_indices]
@staticmethod
def top_k_per_row_prefill(
    logits: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    raw_topk_indices: torch.Tensor,
    num_rows: int,
    stride0: int,
    strdide1: int,
    topk_tokens: int,
) -> None:
    """Select per-row top-k indices and store them relative to each row's window.

    For every row of ``logits`` the ``min(topk_tokens, logits.shape[-1])``
    largest entries are picked. Each index is shifted by the row's
    ``cu_seqlen_ks`` value; shifted indices outside
    ``[0, cu_seqlen_ke - cu_seqlen_ks)`` are replaced with ``-1``.

    Args:
        logits: 2-D score tensor; top-k is taken over the last dimension.
        cu_seqlen_ks: per-row start offset of the valid window.
        cu_seqlen_ke: per-row end offset (exclusive) of the valid window.
        raw_topk_indices: output buffer. Its top-left
            ``[rows, real_topk]`` corner is overwritten in place as int32;
            any remaining columns are left untouched.
        num_rows: unused by this implementation.
        stride0: unused by this implementation.
        strdide1: unused by this implementation (name kept as-is for
            caller compatibility).
        topk_tokens: requested number of indices per row.

    Returns:
        None. The result is written into ``raw_topk_indices``.
    """
    real_topk = min(topk_tokens, logits.shape[-1])
    topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32)

    # Re-base indices so that 0 corresponds to the start of the row's window.
    topk_indices -= cu_seqlen_ks[:, None]

    # Keep only indices that fall inside the row's [0, window_len) range.
    window_len = (cu_seqlen_ke - cu_seqlen_ks)[:, None]
    in_window = (topk_indices >= 0) & (topk_indices < window_len)
    topk_indices.masked_fill_(~in_window, -1)

    raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
        topk_indices
    )
@staticmethod
def top_k_per_row_decode(
    logits: torch.Tensor,
    next_n: int,
    seq_lens: torch.Tensor,
    raw_topk_indices: torch.Tensor,
    num_rows: int,
    stride0: int,
    stride1: int,
    topk_tokens: int,
) -> None:
    """Select per-row top-k indices for decode, masking out-of-range positions.

    Row ``r`` of ``logits`` belongs to request ``r // next_n`` and token
    ``r % next_n``. Its last valid position is
    ``seq_lens[r // next_n] - next_n + r % next_n``. Positions past that
    are masked to ``-inf`` before the top-k, and any selected index past it
    is written as ``-1``.

    Args:
        logits: score tensor whose row count is expected to equal
            ``seq_lens.size(0) * next_n`` (the mask is built with that many
            rows). Not modified; a masked copy is used.
        next_n: number of rows per request.
        seq_lens: per-request sequence lengths.
        raw_topk_indices: output buffer. Its top-left ``[rows, k]`` corner
            is overwritten in place as int32; other columns are untouched.
        num_rows: unused by this implementation.
        stride0: unused by this implementation.
        stride1: unused by this implementation.
        topk_tokens: requested number of indices per row. It is clamped to
            ``logits.shape[-1]`` so short rows do not make ``topk`` raise
            (same clamping as the prefill variant).

    Returns:
        None. The result is written into ``raw_topk_indices``.
    """
    device = logits.device
    batch_size = seq_lens.size(0)
    # padded query len
    padded_num_tokens = batch_size * next_n
    real_topk = min(topk_tokens, logits.shape[-1])

    # positions: [1, L]; broadcasts against index_end_pos below.
    positions = torch.arange(logits.shape[-1], device=device).unsqueeze(0)

    token_ids = torch.arange(padded_num_tokens, device=device)
    row_indices = token_ids // next_n
    next_n_offset = token_ids % next_n
    # index_end_pos: [B * N, 1]
    index_end_pos = (seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)

    # mask: [B * N, L]
    mask = positions <= index_end_pos
    logits = logits.masked_fill(~mask, float("-inf"))

    # [B * N, K]
    topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32)
    # Entries already masked above can still be picked when the context is
    # shorter than K; mark those as invalid.
    topk_indices[topk_indices > index_end_pos] = -1

    raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
        topk_indices
    )
@staticmethod
def register_ops_once() -> None:
    """Register the custom ops, doing the work only on the first call.

    The module-level ``_OPS_REGISTERED`` flag records that registration has
    already happened; later calls return immediately.
    """
    global _OPS_REGISTERED
    if _OPS_REGISTERED:
        return

    # register all the custom ops here
    direct_register_custom_op(
        op_name="xpu_ops_deepseek_scaling_rope",
        op_func=_xpu_ops_deepseek_scaling_rope_impl,
        mutates_args=[],
        fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
        dispatch_key=current_platform.dispatch_key,
    )
    _OPS_REGISTERED = True
# Trigger the one-time custom-op registration defined on xpu_ops.
xpu_ops
.
register_ops_once
()
\ No newline at end of file
vllm/config/compilation.py
View file @
eefa41c1
...
...
@@ -337,9 +337,10 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
assume_32_bit_indexing
:
bool
=
Tru
e
assume_32_bit_indexing
:
bool
=
Fals
e
"""
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
def
compute_hash
(
self
)
->
str
:
...
...
vllm/config/speculative.py
View file @
eefa41c1
...
...
@@ -259,6 +259,16 @@ class SpeculativeConfig:
}
)
if
hf_config
.
architectures
[
0
]
==
"GlmOcrForConditionalGeneration"
:
hf_config
.
model_type
=
"glm_ocr_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
(
{
"num_hidden_layers"
:
0
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"GlmOcrMTPModel"
],
}
)
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
...
...
vllm/distributed/device_communicators/all2all.py
View file @
eefa41c1
...
...
@@ -72,7 +72,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
buffer
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -96,6 +96,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
)
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
extra_tensors
is
not
None
:
raise
NotImplementedError
(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
cu_tokens_across_sp_cpu
=
dp_metadata
.
cu_tokens_across_sp
(
sp_size
)
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_weights
=
self
.
naive_multicast
(
topk_weights
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_ids
=
self
.
naive_multicast
(
topk_ids
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
...
...
@@ -127,7 +155,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def
__init__
(
self
,
cpu_group
,
tcp_store_group
=
None
):
super
().
__init__
(
cpu_group
,
tcp_store_group
)
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -161,6 +189,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """All-gather hidden states, top-k weights and top-k ids along dim 0.

    The gather runs on the EP group when ``is_sequence_parallel`` is set
    and on the DP group otherwise, using the per-rank chunk sizes from the
    forward context's ``dp_metadata``. When ``extra_tensors`` is given,
    those tensors are gathered in the same call and returned as a fourth
    element.
    """
    dp_metadata = get_forward_context().dp_metadata
    assert dp_metadata is not None
    sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
    assert sizes is not None

    if is_sequence_parallel:
        dist_group = get_ep_group()
    else:
        dist_group = get_dp_group()
    # This rank's chunk size must match the number of local rows.
    assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]

    to_gather = [hidden_states, topk_weights, topk_ids]
    if extra_tensors is not None:
        to_gather += extra_tensors

    gathered = dist_group.all_gatherv(to_gather, dim=0, sizes=sizes)
    core = (gathered[0], gathered[1], gathered[2])

    if extra_tensors is None:
        return core
    return (*core, gathered[3:])
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -200,7 +268,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -209,6 +277,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
eefa41c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
threading
from
typing
import
Any
from
weakref
import
WeakValueDictionary
import
torch
...
...
@@ -70,13 +69,32 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise
NotImplementedError
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
Any
:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
...
...
@@ -312,7 +330,7 @@ class DeviceCommunicatorBase:
for
module
in
moe_modules
:
module
.
maybe_init_modular_kernel
()
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -326,8 +344,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
router_logits
,
extra_tensors
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
topk_weights
,
topk_ids
,
extra_tensors
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -338,4 +377,4 @@ class DeviceCommunicatorBase:
return
hidden_states
def
batch_isend_irecv
(
self
,
p2p_ops
:
list
):
raise
NotImplementedError
raise
NotImplementedError
\ No newline at end of file
vllm/distributed/device_communicators/cpu_communicator.py
View file @
eefa41c1
...
...
@@ -151,29 +151,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
(
# type: ignore[override]
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
,
)
return
hidden_states
class
_CPUSHMDistributed
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
eefa41c1
...
...
@@ -396,7 +396,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
def
dispatch
(
# type: ignore[override]
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -406,20 +406,54 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
,
)
def
batch_isend_irecv
(
self
,
p2p_ops
:
list
):
...
...
@@ -427,4 +461,4 @@ class CudaCommunicator(DeviceCommunicatorBase):
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
pynccl_comm
.
batch_isend_irecv
(
p2p_ops
)
else
:
raise
ValueError
(
"No PyNCCL communicator found"
)
raise
ValueError
(
"No PyNCCL communicator found"
)
\ No newline at end of file
vllm/distributed/device_communicators/mnnvl_compat.py
View file @
eefa41c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch.distributed
as
dist
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
...
...
vllm/distributed/device_communicators/xpu_communicator.py
View file @
eefa41c1
...
...
@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
)
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
)
return
hidden_states
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
,
)
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
eefa41c1
...
...
@@ -384,7 +384,9 @@ class TpKVTopology:
@
property
def
split_k_and_v
(
self
)
->
bool
:
# Whether to register regions for K and V separately (when present).
return
not
(
self
.
is_mla
or
self
.
is_kv_layout_blocks_first
)
return
not
(
self
.
_cross_layers_blocks
or
self
.
is_mla
or
self
.
is_kv_layout_blocks_first
)
@
property
def
tp_size
(
self
)
->
int
:
...
...
@@ -554,4 +556,4 @@ def get_current_attn_backend(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
)
->
type
[
AttentionBackend
]:
"""Get the first attention backend for the given layers."""
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
eefa41c1
...
...
@@ -56,7 +56,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
...
...
@@ -186,7 +186,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def
compute_nixl_compatibility_hash
(
vllm_config
:
VllmConfig
,
attn_backend_name
:
str
vllm_config
:
VllmConfig
,
attn_backend_name
:
str
,
cross_layers_blocks
:
bool
)
->
str
:
"""
Compute compatibility hash for NIXL KV transfer.
...
...
@@ -1164,12 +1164,9 @@ class NixlConnectorWorker:
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
info
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
self
.
compat_hash
=
compute_nixl_compatibility_hash
(
self
.
vllm_config
,
self
.
backend_name
)
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
)
# lazy initialized in register_kv_caches
self
.
compat_hash
:
str
|
None
=
None
self
.
kv_topo
:
TpKVTopology
|
None
=
None
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
...
...
@@ -1184,7 +1181,6 @@ class NixlConnectorWorker:
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
)
self
.
_physical_blocks_per_logical_kv_block
=
1
def
_sync_block_size_with_kernel
(
self
)
->
None
:
backends
=
get_current_attn_backends
(
self
.
vllm_config
)
...
...
@@ -1232,6 +1228,7 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
assert
self
.
kv_topo
is
not
None
p_remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks
(
remote_tp_size
)
remote_rank_to_agent_name
=
{}
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
...
...
@@ -1269,6 +1266,7 @@ class NixlConnectorWorker:
)
# Check compatibility hash BEFORE decoding agent metadata
assert
self
.
compat_hash
is
not
None
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
...
...
@@ -1547,7 +1545,6 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
tensor_size_bytes
=
None
# Enable different block lengths for different layers *only* when MLA is used.
...
...
@@ -1698,6 +1695,7 @@ class NixlConnectorWorker:
ssm_sizes
=
self
.
_mamba_ssm_size
,
)
# Wrap metadata in payload with hash for defensive decoding
assert
self
.
compat_hash
is
not
None
encoder
=
msgspec
.
msgpack
.
Encoder
()
self
.
xfer_handshake_metadata
=
NixlHandshakePayload
(
compatibility_hash
=
self
.
compat_hash
,
...
...
@@ -2177,6 +2175,7 @@ class NixlConnectorWorker:
if
len
(
self
.
device_kv_caches
)
==
0
:
return
assert
block_size_ratio
>=
1
,
"Only nP < nD supported currently."
assert
self
.
kv_topo
is
not
None
if
self
.
enable_permute_local_kv
and
block_size_ratio
>
1
:
logger
.
debug
(
"Post-processing device kv cache on receive by converting "
...
...
@@ -2196,7 +2195,7 @@ class NixlConnectorWorker:
block_size_ratio
,
)
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
kv_topo
.
is_kv_layout_blocks_first
)
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
for
block_ids
in
block_ids_list
:
indices
=
torch
.
tensor
(
block_ids
,
device
=
self
.
device_type
,
dtype
=
torch
.
long
)
...
...
@@ -2221,6 +2220,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
assert
self
.
kv_topo
is
not
None
done_sending
=
self
.
_get_new_notifs
()
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
...
...
@@ -2291,6 +2291,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert
self
.
kv_topo
is
not
None
notified_req_ids
:
set
[
str
]
=
set
()
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notif
in
notifs
:
...
...
@@ -2451,7 +2452,7 @@ class NixlConnectorWorker:
self
.
_reqs_to_send
[
req_id
]
=
expiration_time
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
assert
meta
.
remote
is
not
None
and
self
.
kv_topo
is
not
None
remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks_from_engine_id
(
meta
.
remote
.
engine_id
)
...
...
@@ -2782,6 +2783,7 @@ class NixlConnectorWorker:
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
"""
assert
self
.
kv_topo
is
not
None
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# For indexing only half (either just the K or V part).
if
mamba_view
:
...
...
@@ -3103,4 +3105,4 @@ class NixlPromMetrics(KVConnectorPromMetrics):
[
"num_failed_transfers"
,
"num_failed_notifications"
,
"num_kv_expired_reqs"
],
):
for
list_item
in
transfer_stats_data
[
counter_item_key
]:
counter_obj
[
engine_idx
].
inc
(
list_item
)
counter_obj
[
engine_idx
].
inc
(
list_item
)
\ No newline at end of file
vllm/distributed/parallel_state.py
View file @
eefa41c1
...
...
@@ -1065,7 +1065,7 @@ class GroupCoordinator:
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
def
dispatch
(
def
dispatch
_router_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -1076,7 +1076,7 @@ class GroupCoordinator:
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
# type: ignore[call-arg]
return
self
.
device_communicator
.
dispatch
_router_logits
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
...
...
@@ -1085,6 +1085,28 @@ class GroupCoordinator:
else
:
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
,
)
else
:
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -2090,4 +2112,4 @@ def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
if
is_same_node
and
node_assignment
[
other_rank
]
==
0
:
node_assignment
[
other_rank
]
=
next_node_id
return
next_node_id
return
next_node_id
\ No newline at end of file
Prev
1
2
3
4
5
6
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment