Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eefa41c1
Commit
eefa41c1
authored
Mar 24, 2026
by
zhuwenwen
Browse files
sync v0.18.0
parent
82155c76
Changes
253
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1046 additions
and
128 deletions
+1046
-128
tests/models/registry.py
tests/models/registry.py
+15
-13
tests/models/test_transformers.py
tests/models/test_transformers.py
+2
-2
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+2
-2
tests/test_access_log_filter.py
tests/test_access_log_filter.py
+371
-0
tests/v1/e2e/spec_decode/test_spec_decode.py
tests/v1/e2e/spec_decode/test_spec_decode.py
+2
-2
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+3
-3
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+15
-46
vllm/_custom_ops.py
vllm/_custom_ops.py
+26
-11
vllm/_xpu_ops.py
vllm/_xpu_ops.py
+297
-1
vllm/config/compilation.py
vllm/config/compilation.py
+2
-1
vllm/config/speculative.py
vllm/config/speculative.py
+10
-0
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+84
-3
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+44
-5
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+43
-7
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+40
-6
vllm/distributed/device_communicators/mnnvl_compat.py
vllm/distributed/device_communicators/mnnvl_compat.py
+2
-0
vllm/distributed/device_communicators/xpu_communicator.py
vllm/distributed/device_communicators/xpu_communicator.py
+44
-8
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+4
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+15
-13
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+25
-3
No files found.
tests/models/registry.py
View file @
eefa41c1
...
@@ -266,7 +266,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -266,7 +266,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
),
"Exaone4ForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-4.0-32B"
),
"Exaone4ForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/EXAONE-4.0-32B"
),
"ExaoneMoEForCausalLM"
:
_HfExamplesInfo
(
"ExaoneMoEForCausalLM"
:
_HfExamplesInfo
(
"LGAI-EXAONE/K-EXAONE-236B-A23B"
,
min_transformers_version
=
"5.
0
.0"
"LGAI-EXAONE/K-EXAONE-236B-A23B"
,
min_transformers_version
=
"5.
1
.0"
),
),
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
"Fairseq2LlamaForCausalLM"
:
_HfExamplesInfo
(
"mgleize/fairseq2-dummy-Llama-3.2-1B"
),
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
"FalconForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/falcon-7b"
),
...
@@ -283,11 +283,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -283,11 +283,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Glm4MoeForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-4.5"
),
"Glm4MoeForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-4.5"
),
"Glm4MoeLiteForCausalLM"
:
_HfExamplesInfo
(
"Glm4MoeLiteForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-4.7-Flash"
,
"zai-org/GLM-4.7-Flash"
,
min_transformers_version
=
"5.0.0.dev"
,
min_transformers_version
=
"5.0.0"
,
is_available_online
=
False
,
),
"GlmMoeDsaForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-5"
,
min_transformers_version
=
"5.0.1"
,
is_available_online
=
False
),
),
"GlmMoeDsaForCausalLM"
:
_HfExamplesInfo
(
"GlmMoeDsaForCausalLM"
:
_HfExamplesInfo
(
"zai-org/GLM-5"
,
min_transformers_version
=
"5.0.1"
,
is_available_online
=
False
"zai-org/GLM-5"
,
min_transformers_version
=
"5.0.1"
,
is_available_online
=
False
...
@@ -743,7 +739,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -743,7 +739,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
# [Decoder-only]
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
"AriaForConditionalGeneration"
:
_HfExamplesInfo
(
"rhymes-ai/Aria"
),
"AudioFlamingo3ForConditionalGeneration"
:
_HfExamplesInfo
(
"AudioFlamingo3ForConditionalGeneration"
:
_HfExamplesInfo
(
"nvidia/audio-flamingo-3-hf"
,
min_transformers_version
=
"5.0.0
.dev
"
"nvidia/audio-flamingo-3-hf"
,
min_transformers_version
=
"5.0.0"
),
),
"MusicFlamingoForConditionalGeneration"
:
_HfExamplesInfo
(
"MusicFlamingoForConditionalGeneration"
:
_HfExamplesInfo
(
"nvidia/music-flamingo-2601-hf"
,
min_transformers_version
=
"5.0.0.dev"
"nvidia/music-flamingo-2601-hf"
,
min_transformers_version
=
"5.0.0.dev"
...
@@ -1237,7 +1233,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -1237,7 +1233,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel"
:
_HfExamplesInfo
(
"Glm4MoeLiteMTPModel"
:
_HfExamplesInfo
(
"zai-org/GLM-4.7-Flash"
,
"zai-org/GLM-4.7-Flash"
,
speculative_model
=
"zai-org/GLM-4.7-Flash"
,
speculative_model
=
"zai-org/GLM-4.7-Flash"
,
min_transformers_version
=
"5.0.0"
,
),
"GlmOcrMTPModel"
:
_HfExamplesInfo
(
"zai-org/GLM-OCR"
,
speculative_model
=
"zai-org/GLM-OCR"
,
is_available_online
=
False
,
is_available_online
=
False
,
min_transformers_version
=
"5.1.0"
,
),
),
"LongCatFlashMTPModel"
:
_HfExamplesInfo
(
"LongCatFlashMTPModel"
:
_HfExamplesInfo
(
"meituan-longcat/LongCat-Flash-Chat"
,
"meituan-longcat/LongCat-Flash-Chat"
,
...
@@ -1282,27 +1284,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -1282,27 +1284,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS
=
{
_TRANSFORMERS_BACKEND_MODELS
=
{
"TransformersEmbeddingModel"
:
_HfExamplesInfo
(
"TransformersEmbeddingModel"
:
_HfExamplesInfo
(
"BAAI/bge-base-en-v1.5"
,
min_transformers_version
=
"5.0.0
.dev
"
"BAAI/bge-base-en-v1.5"
,
min_transformers_version
=
"5.0.0"
),
),
"TransformersForSequenceClassification"
:
_HfExamplesInfo
(
"TransformersForSequenceClassification"
:
_HfExamplesInfo
(
"papluca/xlm-roberta-base-language-detection"
,
"papluca/xlm-roberta-base-language-detection"
,
min_transformers_version
=
"5.0.0
.dev
"
,
min_transformers_version
=
"5.0.0"
,
),
),
"TransformersForCausalLM"
:
_HfExamplesInfo
(
"TransformersForCausalLM"
:
_HfExamplesInfo
(
"hmellor/Ilama-3.2-1B"
,
trust_remote_code
=
True
"hmellor/Ilama-3.2-1B"
,
trust_remote_code
=
True
),
),
"TransformersMultiModalForCausalLM"
:
_HfExamplesInfo
(
"BAAI/Emu3-Chat-hf"
),
"TransformersMultiModalForCausalLM"
:
_HfExamplesInfo
(
"BAAI/Emu3-Chat-hf"
),
"TransformersMoEForCausalLM"
:
_HfExamplesInfo
(
"TransformersMoEForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924"
,
min_transformers_version
=
"5.0.0
.dev
"
"allenai/OLMoE-1B-7B-0924"
,
min_transformers_version
=
"5.0.0"
),
),
"TransformersMultiModalMoEForCausalLM"
:
_HfExamplesInfo
(
"TransformersMultiModalMoEForCausalLM"
:
_HfExamplesInfo
(
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
min_transformers_version
=
"5.0.0
.dev
"
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
min_transformers_version
=
"5.0.0"
),
),
"TransformersMoEEmbeddingModel"
:
_HfExamplesInfo
(
"TransformersMoEEmbeddingModel"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0
.dev
"
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0"
),
),
"TransformersMoEForSequenceClassification"
:
_HfExamplesInfo
(
"TransformersMoEForSequenceClassification"
:
_HfExamplesInfo
(
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0
.dev
"
"Qwen/Qwen3-30B-A3B"
,
min_transformers_version
=
"5.0.0"
),
),
"TransformersMultiModalEmbeddingModel"
:
_HfExamplesInfo
(
"google/gemma-3-4b-it"
),
"TransformersMultiModalEmbeddingModel"
:
_HfExamplesInfo
(
"google/gemma-3-4b-it"
),
"TransformersMultiModalForSequenceClassification"
:
_HfExamplesInfo
(
"TransformersMultiModalForSequenceClassification"
:
_HfExamplesInfo
(
...
...
tests/models/test_transformers.py
View file @
eefa41c1
...
@@ -76,7 +76,7 @@ def test_models(
...
@@ -76,7 +76,7 @@ def test_models(
from
packaging.version
import
Version
from
packaging.version
import
Version
installed
=
Version
(
transformers
.
__version__
)
installed
=
Version
(
transformers
.
__version__
)
required
=
Version
(
"5.0.0
.dev
"
)
required
=
Version
(
"5.0.0"
)
if
model
==
"allenai/OLMoE-1B-7B-0924"
and
installed
<
required
:
if
model
==
"allenai/OLMoE-1B-7B-0924"
and
installed
<
required
:
pytest
.
skip
(
pytest
.
skip
(
"MoE models with the Transformers modeling backend require "
"MoE models with the Transformers modeling backend require "
...
@@ -237,4 +237,4 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
...
@@ -237,4 +237,4 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
embeddings_1_lst
=
vllm_outputs
,
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
)
)
\ No newline at end of file
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
eefa41c1
...
@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module):
...
@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
@@ -59,4 +59,4 @@ class MyGemma2Embedding(nn.Module):
...
@@ -59,4 +59,4 @@ class MyGemma2Embedding(nn.Module):
weights
=
(
weights
=
(
(
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
)
(
name
,
data
)
for
name
,
data
in
weights
if
not
name
.
startswith
(
"lm_head."
)
)
)
return
self
.
model
.
load_weights
(
weights
)
return
self
.
model
.
load_weights
(
weights
)
\ No newline at end of file
tests/test_access_log_filter.py
0 → 100644
View file @
eefa41c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the UvicornAccessLogFilter class.
"""
import
logging
from
vllm.logging_utils.access_log_filter
import
(
UvicornAccessLogFilter
,
create_uvicorn_log_config
,
)
class
TestUvicornAccessLogFilter
:
"""Test cases for UvicornAccessLogFilter."""
def
test_filter_allows_all_when_no_excluded_paths
(
self
):
"""Filter should allow all logs when no paths are excluded."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[])
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/v1/completions"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
True
def
test_filter_allows_all_when_excluded_paths_is_none
(
self
):
"""Filter should allow all logs when excluded_paths is None."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
None
)
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/health"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
True
def
test_filter_excludes_health_endpoint
(
self
):
"""Filter should exclude /health endpoint when configured."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
])
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/health"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
False
def
test_filter_excludes_metrics_endpoint
(
self
):
"""Filter should exclude /metrics endpoint when configured."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/metrics"
])
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/metrics"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
False
def
test_filter_allows_non_excluded_endpoints
(
self
):
"""Filter should allow endpoints not in the excluded list."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
,
"/metrics"
])
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"POST"
,
"/v1/completions"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
True
def
test_filter_excludes_multiple_endpoints
(
self
):
"""Filter should exclude multiple configured endpoints."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
,
"/metrics"
,
"/ping"
])
# Test /health
record_health
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/health"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record_health
)
is
False
# Test /metrics
record_metrics
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/metrics"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record_metrics
)
is
False
# Test /ping
record_ping
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/ping"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record_ping
)
is
False
def
test_filter_with_query_parameters
(
self
):
"""Filter should exclude endpoints even with query parameters."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
])
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/health?verbose=true"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
False
def
test_filter_different_http_methods
(
self
):
"""Filter should exclude endpoints regardless of HTTP method."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/ping"
])
# Test GET
record_get
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/ping"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record_get
)
is
False
# Test POST
record_post
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"POST"
,
"/ping"
,
"1.1"
,
200
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record_post
)
is
False
def
test_filter_with_different_status_codes
(
self
):
"""Filter should exclude endpoints regardless of status code."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
])
for
status_code
in
[
200
,
500
,
503
]:
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
'%s - "%s %s HTTP/%s" %d'
,
args
=
(
"127.0.0.1:12345"
,
"GET"
,
"/health"
,
"1.1"
,
status_code
),
exc_info
=
None
,
)
assert
filter
.
filter
(
record
)
is
False
class
TestCreateUvicornLogConfig
:
"""Test cases for create_uvicorn_log_config function."""
def
test_creates_valid_config_structure
(
self
):
"""Config should have required logging configuration keys."""
config
=
create_uvicorn_log_config
(
excluded_paths
=
[
"/health"
])
assert
"version"
in
config
assert
config
[
"version"
]
==
1
assert
"disable_existing_loggers"
in
config
assert
"formatters"
in
config
assert
"handlers"
in
config
assert
"loggers"
in
config
assert
"filters"
in
config
def
test_config_includes_access_log_filter
(
self
):
"""Config should include the access log filter."""
config
=
create_uvicorn_log_config
(
excluded_paths
=
[
"/health"
,
"/metrics"
])
assert
"access_log_filter"
in
config
[
"filters"
]
filter_config
=
config
[
"filters"
][
"access_log_filter"
]
assert
filter_config
[
"()"
]
==
UvicornAccessLogFilter
assert
filter_config
[
"excluded_paths"
]
==
[
"/health"
,
"/metrics"
]
def
test_config_applies_filter_to_access_handler
(
self
):
"""Config should apply the filter to the access handler."""
config
=
create_uvicorn_log_config
(
excluded_paths
=
[
"/health"
])
assert
"access"
in
config
[
"handlers"
]
assert
"filters"
in
config
[
"handlers"
][
"access"
]
assert
"access_log_filter"
in
config
[
"handlers"
][
"access"
][
"filters"
]
def
test_config_with_custom_log_level
(
self
):
"""Config should respect custom log level."""
config
=
create_uvicorn_log_config
(
excluded_paths
=
[
"/health"
],
log_level
=
"debug"
)
assert
config
[
"loggers"
][
"uvicorn"
][
"level"
]
==
"DEBUG"
assert
config
[
"loggers"
][
"uvicorn.access"
][
"level"
]
==
"DEBUG"
assert
config
[
"loggers"
][
"uvicorn.error"
][
"level"
]
==
"DEBUG"
def
test_config_with_empty_excluded_paths
(
self
):
"""Config should work with empty excluded paths."""
config
=
create_uvicorn_log_config
(
excluded_paths
=
[])
assert
config
[
"filters"
][
"access_log_filter"
][
"excluded_paths"
]
==
[]
def
test_config_with_none_excluded_paths
(
self
):
"""Config should work with None excluded paths."""
config
=
create_uvicorn_log_config
(
excluded_paths
=
None
)
assert
config
[
"filters"
][
"access_log_filter"
][
"excluded_paths"
]
==
[]
class
TestIntegration
:
"""Integration tests for the access log filter."""
def
test_filter_with_real_logger
(
self
):
"""Test filter works with a real Python logger simulating uvicorn."""
# Create a logger with our filter (simulating uvicorn.access)
logger
=
logging
.
getLogger
(
"uvicorn.access"
)
logger
.
setLevel
(
logging
.
INFO
)
# Clear any existing handlers
logger
.
handlers
=
[]
# Create a custom handler that tracks messages
logged_messages
:
list
[
str
]
=
[]
class
TrackingHandler
(
logging
.
Handler
):
def
emit
(
self
,
record
):
logged_messages
.
append
(
record
.
getMessage
())
handler
=
TrackingHandler
()
handler
.
setLevel
(
logging
.
INFO
)
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
,
"/metrics"
])
handler
.
addFilter
(
filter
)
logger
.
addHandler
(
handler
)
# Log using uvicorn's format with args tuple
# Format: '%s - "%s %s HTTP/%s" %d'
logger
.
info
(
'%s - "%s %s HTTP/%s" %d'
,
"127.0.0.1:12345"
,
"GET"
,
"/health"
,
"1.1"
,
200
,
)
logger
.
info
(
'%s - "%s %s HTTP/%s" %d'
,
"127.0.0.1:12345"
,
"GET"
,
"/v1/completions"
,
"1.1"
,
200
,
)
logger
.
info
(
'%s - "%s %s HTTP/%s" %d'
,
"127.0.0.1:12345"
,
"GET"
,
"/metrics"
,
"1.1"
,
200
,
)
logger
.
info
(
'%s - "%s %s HTTP/%s" %d'
,
"127.0.0.1:12345"
,
"POST"
,
"/v1/chat/completions"
,
"1.1"
,
200
,
)
# Verify only non-excluded endpoints were logged
assert
len
(
logged_messages
)
==
2
assert
"/v1/completions"
in
logged_messages
[
0
]
assert
"/v1/chat/completions"
in
logged_messages
[
1
]
def
test_filter_allows_non_uvicorn_access_logs
(
self
):
"""Test filter allows logs from non-uvicorn.access loggers."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
])
# Log record from a different logger name
record
=
logging
.
LogRecord
(
name
=
"uvicorn.error"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
"Some error message about /health"
,
args
=
(),
exc_info
=
None
,
)
# Should allow because it's not from uvicorn.access
assert
filter
.
filter
(
record
)
is
True
def
test_filter_handles_malformed_args
(
self
):
"""Test filter handles log records with unexpected args format."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
])
# Log record with insufficient args
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
"Some message"
,
args
=
(
"only"
,
"two"
),
exc_info
=
None
,
)
# Should allow because args doesn't have expected format
assert
filter
.
filter
(
record
)
is
True
def
test_filter_handles_non_tuple_args
(
self
):
"""Test filter handles log records with non-tuple args."""
filter
=
UvicornAccessLogFilter
(
excluded_paths
=
[
"/health"
])
# Log record with None args
record
=
logging
.
LogRecord
(
name
=
"uvicorn.access"
,
level
=
logging
.
INFO
,
pathname
=
""
,
lineno
=
0
,
msg
=
"Some message without args"
,
args
=
None
,
exc_info
=
None
,
)
# Should allow because args is None
assert
filter
.
filter
(
record
)
is
True
\ No newline at end of file
tests/v1/e2e/spec_decode/test_spec_decode.py
View file @
eefa41c1
...
@@ -383,7 +383,7 @@ def _run_eagle_correctness(
...
@@ -383,7 +383,7 @@ def _run_eagle_correctness(
from
packaging.version
import
Version
from
packaging.version
import
Version
installed
=
Version
(
transformers
.
__version__
)
installed
=
Version
(
transformers
.
__version__
)
required
=
Version
(
"5.0.0
.dev
"
)
required
=
Version
(
"5.0.0"
)
if
installed
<
required
:
if
installed
<
required
:
pytest
.
skip
(
pytest
.
skip
(
"Eagle3 with the Transformers modeling backend requires "
"Eagle3 with the Transformers modeling backend requires "
...
@@ -1030,4 +1030,4 @@ def compute_acceptance_len(metrics: list[Metric]) -> float:
...
@@ -1030,4 +1030,4 @@ def compute_acceptance_len(metrics: list[Metric]) -> float:
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
if
n_drafts
==
0
:
if
n_drafts
==
0
:
return
1
return
1
return
1
+
(
n_accepted_toks
/
n_drafts
)
return
1
+
(
n_accepted_toks
/
n_drafts
)
\ No newline at end of file
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
eefa41c1
...
@@ -59,9 +59,9 @@ fi
...
@@ -59,9 +59,9 @@ fi
# Build the kv-transfer-config once
# Build the kv-transfer-config once
if
[[
"
$KV_BUFFER_DEVICE
"
==
"cuda"
]]
;
then
if
[[
"
$KV_BUFFER_DEVICE
"
==
"cuda"
]]
;
then
KV_CONFIG
=
'{"kv_connector":"NixlConnector","kv_role":"kv_both"'
${
KV_CONFIG_HETERO_LAYOUT
}
'}'
KV_CONFIG
=
'{"kv_connector":"NixlConnector","kv_role":"kv_both"'
${
KV_CONFIG_HETERO_LAYOUT
}
${
KV_EXTRA_CONFIG
}
'}'
else
else
KV_CONFIG
=
"{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
,
\"
kv_buffer_device
\"
:
\"
$KV_BUFFER_DEVICE
\"
"
${
KV_CONFIG_HETERO_LAYOUT
}
"}"
KV_CONFIG
=
"{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
,
\"
kv_buffer_device
\"
:
\"
$KV_BUFFER_DEVICE
\"
"
${
KV_CONFIG_HETERO_LAYOUT
}
${
KV_EXTRA_CONFIG
}
"}"
fi
fi
# Models to run
# Models to run
...
@@ -295,4 +295,4 @@ for model in "${MODELS[@]}"; do
...
@@ -295,4 +295,4 @@ for model in "${MODELS[@]}"; do
run_tests_for_model
"
$model
"
run_tests_for_model
"
$model
"
done
done
echo
"All tests completed!"
echo
"All tests completed!"
\ No newline at end of file
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
eefa41c1
...
@@ -18,8 +18,12 @@ import ray
...
@@ -18,8 +18,12 @@ import ray
import
torch
import
torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
KVTransferConfig
,
set_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
KVOutputAggregator
,
TpKVTopology
,
get_current_attn_backend
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
nixl_connector
from
vllm.distributed.kv_transfer.kv_connector.v1
import
nixl_connector
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
...
@@ -58,6 +62,8 @@ from vllm.v1.kv_cache_interface import (
...
@@ -58,6 +62,8 @@ from vllm.v1.kv_cache_interface import (
)
)
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.utils
import
AttentionGroup
from
.utils
import
(
from
.utils
import
(
create_request
,
create_request
,
...
@@ -1498,44 +1504,6 @@ def test_register_kv_caches(
...
@@ -1498,44 +1504,6 @@ def test_register_kv_caches(
backend_cls
=
TritonAttentionBackend
backend_cls
=
TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
"layer2"
:
shared_tensor
,
}
# Store tensor info for validation
test_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
)
is_blocks_first
=
len
(
test_shape
)
==
5
and
test_shape
[
0
]
==
1
if
is_blocks_first
:
expected_tensor_size
=
shared_tensor
.
element_size
()
*
shared_tensor
.
numel
()
expected_base_addrs
=
[
shared_tensor
.
data_ptr
(),
unique_tensor
.
data_ptr
(),
]
expected_num_entries
=
2
else
:
expected_tensor_size
=
(
shared_tensor
[
0
].
element_size
()
*
shared_tensor
[
0
].
numel
()
)
expected_base_addrs
=
[
shared_tensor
[
0
].
data_ptr
(),
shared_tensor
[
1
].
data_ptr
(),
unique_tensor
[
0
].
data_ptr
(),
unique_tensor
[
1
].
data_ptr
(),
]
expected_num_entries
=
4
nixl_module
=
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
nixl_module
=
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with
(
with
(
patch
(
f
"
{
nixl_module
}
.NixlWrapper"
)
as
mock_nixl_wrapper
,
patch
(
f
"
{
nixl_module
}
.NixlWrapper"
)
as
mock_nixl_wrapper
,
...
@@ -1716,14 +1684,13 @@ def test_register_kv_caches(
...
@@ -1716,14 +1684,13 @@ def test_register_kv_caches(
blocks_data
,
_
=
mock_wrapper_instance
.
get_xfer_descs
.
call_args
[
0
]
blocks_data
,
_
=
mock_wrapper_instance
.
get_xfer_descs
.
call_args
[
0
]
# Validate blocks_data structure and size
# Validate blocks_data structure and size
expected_blocks_count
=
8
assert
len
(
blocks_data
)
==
expected_blocks_count
,
(
assert
len
(
blocks_data
)
==
expected_blocks_count
,
(
f
"Expected
{
expected_blocks_count
}
blocks, got
{
len
(
blocks_data
)
}
"
f
"Expected
{
expected_blocks_count
}
blocks, got
{
len
(
blocks_data
)
}
"
)
)
num
_blocks
=
2
if
connector
.
prefer_cross_layer
_blocks
:
if
is_blocks_first
:
num_blocks
=
8
expected_block_len
=
expected_tensor_size
//
num_blocks
//
2
expected_block_len
=
expected_tensor_size
//
num_blocks
else
:
else
:
num_blocks
=
kv_cache_config
.
num_blocks
num_blocks
=
kv_cache_config
.
num_blocks
if
is_blocks_first
:
if
is_blocks_first
:
...
@@ -2360,7 +2327,9 @@ def test_compatibility_hash_validation(
...
@@ -2360,7 +2327,9 @@ def test_compatibility_hash_validation(
)
)
)
)
remote_hash
=
compute_nixl_compatibility_hash
(
remote_hash
=
compute_nixl_compatibility_hash
(
remote_vllm_config
,
decode_worker
.
backend_name
remote_vllm_config
,
decode_worker
.
backend_name
,
decode_worker
.
kv_topo
.
cross_layers_blocks
,
)
)
prefill_block_size
=
config_overrides
.
get
(
"block_size"
,
16
)
prefill_block_size
=
config_overrides
.
get
(
"block_size"
,
16
)
...
@@ -2497,4 +2466,4 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
...
@@ -2497,4 +2466,4 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
port
=
1234
,
port
=
1234
,
remote_tp_size
=
1
,
remote_tp_size
=
1
,
expected_engine_id
=
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
expected_engine_id
=
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
)
)
\ No newline at end of file
vllm/_custom_ops.py
View file @
eefa41c1
...
@@ -3044,13 +3044,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
...
@@ -3044,13 +3044,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class
CPUDNNLGEMMHandler
:
class
CPUDNNLGEMMHandler
:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
handler
:
int
|
None
=
None
self
.
handler
_tensor
:
torch
.
Tensor
|
None
=
None
self
.
n
=
-
1
self
.
n
=
-
1
self
.
k
=
-
1
self
.
k
=
-
1
def
__del__
(
self
):
def
__del__
(
self
):
if
self
.
handler
is
not
None
:
if
self
.
handler
_tensor
is
not
None
:
torch
.
ops
.
_C
.
release_dnnl_matmul_handler
(
self
.
handler
)
torch
.
ops
.
_C
.
release_dnnl_matmul_handler
(
self
.
handler
_tensor
.
item
()
)
_supports_onednn
=
bool
(
hasattr
(
torch
.
ops
.
_C
,
"create_onednn_mm_handler"
))
_supports_onednn
=
bool
(
hasattr
(
torch
.
ops
.
_C
,
"create_onednn_mm_handler"
))
...
@@ -3066,8 +3066,10 @@ def create_onednn_mm(
...
@@ -3066,8 +3066,10 @@ def create_onednn_mm(
)
->
CPUDNNLGEMMHandler
:
)
->
CPUDNNLGEMMHandler
:
handler
=
CPUDNNLGEMMHandler
()
handler
=
CPUDNNLGEMMHandler
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
handler
=
torch
.
ops
.
_C
.
create_onednn_mm_handler
(
# store the handler pointer in a tensor it doesn't get inlined
weight
,
primitive_cache_size
handler
.
handler_tensor
=
torch
.
tensor
(
torch
.
ops
.
_C
.
create_onednn_mm_handler
(
weight
,
primitive_cache_size
),
dtype
=
torch
.
int64
,
)
)
return
handler
return
handler
...
@@ -3079,7 +3081,7 @@ def onednn_mm(
...
@@ -3079,7 +3081,7 @@ def onednn_mm(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty
((
*
x
.
shape
[
0
:
-
1
],
dnnl_handler
.
n
),
dtype
=
x
.
dtype
)
output
=
torch
.
empty
((
*
x
.
shape
[
0
:
-
1
],
dnnl_handler
.
n
),
dtype
=
x
.
dtype
)
torch
.
ops
.
_C
.
onednn_mm
(
torch
.
ops
.
_C
.
onednn_mm
(
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
_tensor
)
)
return
output
return
output
...
@@ -3095,8 +3097,17 @@ def create_onednn_scaled_mm(
...
@@ -3095,8 +3097,17 @@ def create_onednn_scaled_mm(
)
->
CPUDNNLGEMMHandler
:
)
->
CPUDNNLGEMMHandler
:
handler
=
CPUDNNLGEMMHandler
()
handler
=
CPUDNNLGEMMHandler
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
k
,
handler
.
n
=
weight
.
size
()
handler
.
handler
=
torch
.
ops
.
_C
.
create_onednn_scaled_mm_handler
(
# store the handler pointer in a tensor so it doesn't get inlined
weight
,
weight_scales
,
output_type
,
dynamic_quant
,
use_azp
,
primitive_cache_size
handler
.
handler_tensor
=
torch
.
tensor
(
torch
.
ops
.
_C
.
create_onednn_scaled_mm_handler
(
weight
,
weight_scales
,
output_type
,
dynamic_quant
,
use_azp
,
primitive_cache_size
,
),
dtype
=
torch
.
int64
,
)
)
return
handler
return
handler
...
@@ -3149,11 +3160,15 @@ def onednn_scaled_mm(
...
@@ -3149,11 +3160,15 @@ def onednn_scaled_mm(
bias
:
torch
.
Tensor
|
None
,
bias
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
torch
.
ops
.
_C
.
onednn_scaled_mm
(
torch
.
ops
.
_C
.
onednn_scaled_mm
(
output
,
x
,
input_scale
,
input_zp
,
input_zp_adj
,
bias
,
dnnl_handler
.
handler
output
,
x
,
input_scale
,
input_zp
,
input_zp_adj
,
bias
,
dnnl_handler
.
handler_tensor
,
)
)
return
output
def
cpu_attn_get_scheduler_metadata
(
def
cpu_attn_get_scheduler_metadata
(
num_reqs
:
int
,
num_reqs
:
int
,
...
...
vllm/_xpu_ops.py
View file @
eefa41c1
...
@@ -7,6 +7,8 @@ import torch
...
@@ -7,6 +7,8 @@ import torch
from
vllm_xpu_kernels.flash_attn_interface
import
flash_attn_varlen_func
from
vllm_xpu_kernels.flash_attn_interface
import
flash_attn_varlen_func
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -53,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
...
@@ -53,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return
torch
.
empty
((
M
,
N
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
return
torch
.
empty
((
M
,
N
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
_xpu_ops_deepseek_scaling_rope_impl
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
,
offsets
:
torch
.
Tensor
|
None
,
cos_sin_cache
:
torch
.
Tensor
|
None
,
rotary_dim
:
int
,
is_neox_style
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
key
is
not
None
return
torch
.
ops
.
_xpu_C
.
deepseek_scaling_rope
(
positions
,
query
,
key
,
offsets
,
cos_sin_cache
,
rotary_dim
,
is_neox_style
)
def
_xpu_ops_deepseek_scaling_rope_fake
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
,
offsets
:
torch
.
Tensor
|
None
,
cos_sin_cache
:
torch
.
Tensor
|
None
,
rotary_dim
:
int
,
is_neox_style
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
query
,
key
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
class
xpu_ops
:
class
xpu_ops
:
@
staticmethod
@
staticmethod
def
flash_attn_varlen_func
(
def
flash_attn_varlen_func
(
...
@@ -105,9 +138,10 @@ class xpu_ops:
...
@@ -105,9 +138,10 @@ class xpu_ops:
assert
len
(
window_size
)
==
2
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
# noqa: F841
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
# noqa: F841
# In encode attention, v maybe not contiguous and current
# In encode attention,
k and
v maybe not contiguous and current
# kernel can't handle it
# kernel can't handle it
if
block_table
is
None
:
if
block_table
is
None
:
k
=
k
.
contiguous
()
v
=
v
.
contiguous
()
v
=
v
.
contiguous
()
return
flash_attn_varlen_func
(
return
flash_attn_varlen_func
(
out
=
out
,
out
=
out
,
...
@@ -156,3 +190,265 @@ class xpu_ops:
...
@@ -156,3 +190,265 @@ class xpu_ops:
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
)
)
return
None
return
None
@
staticmethod
def
indexer_k_quant_and_cache
(
k
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
quant_block_size
:
int
,
scale_fmt
:
str
|
None
,
)
->
None
:
head_dim
=
k
.
shape
[
-
1
]
k
=
k
.
view
(
-
1
,
head_dim
)
# [total_tokens, head_dim]
def
group_quant_torch
(
x
:
torch
.
Tensor
,
group_size
:
int
,
eps
:
float
=
1e-10
,
dtype
:
torch
.
dtype
|
None
=
None
,
column_major_scales
:
bool
=
False
,
out_q
:
torch
.
Tensor
|
None
=
None
,
use_ue8m0
:
bool
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
use_ue8m0
is
None
:
# Default fallback - could import is_deep_gemm_e8m0_used if needed
use_ue8m0
=
False
if
dtype
is
None
:
dtype
=
current_platform
.
fp8_dtype
()
# Validate inputs
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
(
f
"Last dimension
{
x
.
shape
[
-
1
]
}
must be divisible by "
f
"group_size
{
group_size
}
"
)
assert
x
.
stride
(
-
1
)
==
1
,
"Input tensor groups must be contiguous"
# Prepare output tensor
if
out_q
is
None
:
x_q
=
torch
.
empty_like
(
x
,
dtype
=
dtype
)
else
:
assert
out_q
.
shape
==
x
.
shape
x_q
=
out_q
# Reshape input for group processing
# Original shape: (..., last_dim)
# Target shape: (..., num_groups, group_size)
original_shape
=
x
.
shape
num_groups
=
original_shape
[
-
1
]
//
group_size
# Reshape to separate groups
group_shape
=
original_shape
[:
-
1
]
+
(
num_groups
,
group_size
)
x_grouped
=
x
.
view
(
group_shape
)
# Compute per-group absolute maximum values
# Shape: (..., num_groups)
abs_max
=
torch
.
amax
(
torch
.
abs
(
x_grouped
),
dim
=-
1
,
keepdim
=
False
)
abs_max
=
torch
.
maximum
(
abs_max
,
torch
.
tensor
(
eps
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
)
# Compute scales
FP8_MAX
=
torch
.
finfo
(
dtype
).
max
FP8_MIN
=
torch
.
finfo
(
dtype
).
min
scale_raw
=
abs_max
/
FP8_MAX
if
use_ue8m0
:
# For UE8M0 format, scales must be powers of 2
scales
=
torch
.
pow
(
2.0
,
torch
.
ceil
(
torch
.
log2
(
scale_raw
)))
else
:
scales
=
scale_raw
# Expand scales for broadcasting with grouped data
# Shape: (..., num_groups, 1)
scales_expanded
=
scales
.
unsqueeze
(
-
1
)
# Quantize the grouped data
x_scaled
=
x_grouped
/
scales_expanded
x_clamped
=
torch
.
clamp
(
x_scaled
,
FP8_MIN
,
FP8_MAX
)
x_quantized
=
x_clamped
.
to
(
dtype
)
# Reshape back to original shape
x_q
.
copy_
(
x_quantized
.
view
(
original_shape
))
# Prepare scales tensor in requested format
if
column_major_scales
:
# Column-major: (num_groups,) + batch_dims
# Transpose the scales to put group dimension first
scales_shape
=
(
num_groups
,)
+
original_shape
[:
-
1
]
x_s
=
scales
.
permute
(
-
1
,
*
range
(
len
(
original_shape
)
-
1
))
x_s
=
x_s
.
contiguous
().
view
(
scales_shape
)
else
:
# Row-major: batch_dims + (num_groups,)
x_s
=
scales
.
contiguous
()
# Ensure scales are float32
return
x_q
,
x_s
.
float
()
k_fp8
,
k_scale
=
group_quant_torch
(
k
,
group_size
=
quant_block_size
,
column_major_scales
=
False
,
use_ue8m0
=
(
scale_fmt
==
"ue8m0"
),
)
k_fp8_bytes
=
k_fp8
.
view
(
-
1
,
head_dim
).
view
(
torch
.
uint8
)
scale_bytes
=
k_scale
.
view
(
torch
.
uint8
).
view
(
-
1
,
4
)
k
=
torch
.
cat
(
[
k_fp8_bytes
,
scale_bytes
],
dim
=-
1
)
# [total_tokens, head_dim + 4]
slot_mapping
=
slot_mapping
.
flatten
()
# kv_cache: [num_block, block_size, head_dim + 4]
kv_cache
.
view
(
-
1
,
kv_cache
.
shape
[
-
1
]).
index_copy_
(
0
,
slot_mapping
,
k
)
@
staticmethod
def
cp_gather_indexer_k_quant_cache
(
kv_cache
:
torch
.
Tensor
,
dst_k
:
torch
.
Tensor
,
dst_scale
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
)
->
None
:
"""
Args:
kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache
Layout per block: [k_values, scale_values]
- k_values: [block_size * head_dim]
- scale_values: [block_size * head_dim * 4 / quant_block_size]
dst_k: [num_tokens, head_dim] - output tensor for K values
dst_scale: [num_tokens, head_dim / quant_block_size * 4]
- output tensor for scale values
block_table: [batch_size, num_blocks] - block table for indexing
cu_seq_lens: [batch_size + 1] - cumulative sequence lengths
"""
batch_size
=
block_table
.
size
(
0
)
num_tokens
=
dst_k
.
size
(
0
)
head_dim
=
dst_k
.
size
(
1
)
cache_block_size
=
kv_cache
.
size
(
1
)
quant_block_size
=
head_dim
*
4
//
dst_scale
.
size
(
1
)
# For each token, find which batch it belongs to using searchsorted
token_indices
=
torch
.
arange
(
num_tokens
,
device
=
dst_k
.
device
)
+
1
# cu_seq_lens is [batch_size + 1], we need to find which interval each
# token belongs to
batch_indices
=
torch
.
searchsorted
(
cu_seq_lens
,
token_indices
)
-
1
batch_indices
=
torch
.
clamp
(
batch_indices
,
0
,
batch_size
-
1
)
# Calculate the in-batch sequence index for each token
inbatch_seq_indices
=
token_indices
-
cu_seq_lens
[
batch_indices
]
# Find which block each token belongs to
block_indices_in_table
=
inbatch_seq_indices
//
cache_block_size
physical_block_indices
=
block_table
[
batch_indices
,
block_indices_in_table
]
# Calculate the offset within each block
inblock_offsets
=
(
inbatch_seq_indices
-
1
)
%
cache_block_size
# Calculate strides
block_stride
=
kv_cache
.
stride
(
0
)
# stride for each block
# Flatten kv_cache for easier indexing
kv_cache_flat
=
kv_cache
.
view
(
-
1
)
# Calculate source offset for K values for all tokens (vectorized)
src_block_offsets
=
physical_block_indices
*
block_stride
src_k_offsets
=
src_block_offsets
+
inblock_offsets
*
head_dim
# Gather K values using advanced indexing
# Create indices for all elements we need to gather
k_indices
=
src_k_offsets
.
unsqueeze
(
1
)
+
torch
.
arange
(
head_dim
,
device
=
dst_k
.
device
)
dst_k
[:]
=
kv_cache_flat
[
k_indices
]
# Calculate source offset for scale values (vectorized)
# Scales are stored after all K values for each block
scale_size
=
head_dim
*
4
//
quant_block_size
src_scale_offsets
=
src_block_offsets
+
head_dim
+
inblock_offsets
*
scale_size
# Gather scale values
scale_indices
=
src_scale_offsets
.
unsqueeze
(
1
)
+
torch
.
arange
(
scale_size
,
device
=
dst_scale
.
device
)
dst_scale
[:]
=
kv_cache_flat
[
scale_indices
]
@
staticmethod
def
top_k_per_row_prefill
(
logits
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
raw_topk_indices
:
torch
.
Tensor
,
num_rows
:
int
,
stride0
:
int
,
strdide1
:
int
,
topk_tokens
:
int
,
)
->
torch
.
Tensor
:
real_topk
=
min
(
topk_tokens
,
logits
.
shape
[
-
1
])
topk_indices
=
logits
.
topk
(
real_topk
,
dim
=-
1
)[
1
].
to
(
torch
.
int32
)
topk_indices
-=
cu_seqlen_ks
[:,
None
]
mask_lo
=
topk_indices
>=
0
mask_hi
=
topk_indices
-
(
cu_seqlen_ke
-
cu_seqlen_ks
)[:,
None
]
<
0
mask
=
torch
.
full_like
(
topk_indices
,
False
,
dtype
=
torch
.
bool
,
device
=
topk_indices
.
device
)
mask
=
mask_lo
&
mask_hi
topk_indices
.
masked_fill_
(
~
mask
,
-
1
)
raw_topk_indices
[:
topk_indices
.
shape
[
0
],
:
topk_indices
.
shape
[
1
]]
=
(
topk_indices
)
@
staticmethod
def
top_k_per_row_decode
(
logits
:
torch
.
Tensor
,
next_n
:
int
,
seq_lens
:
torch
.
Tensor
,
raw_topk_indices
:
torch
.
Tensor
,
num_rows
:
int
,
stride0
:
int
,
stride1
:
int
,
topk_tokens
:
int
,
)
->
torch
.
Tensor
:
device
=
logits
.
device
batch_size
=
seq_lens
.
size
(
0
)
# padded query len
padded_num_tokens
=
batch_size
*
next_n
positions
=
(
torch
.
arange
(
logits
.
shape
[
-
1
],
device
=
device
)
.
unsqueeze
(
0
)
.
expand
(
batch_size
*
next_n
,
-
1
)
)
row_indices
=
torch
.
arange
(
padded_num_tokens
,
device
=
device
)
//
next_n
next_n_offset
=
torch
.
arange
(
padded_num_tokens
,
device
=
device
)
%
next_n
index_end_pos
=
(
seq_lens
[
row_indices
]
-
next_n
+
next_n_offset
).
unsqueeze
(
1
)
# index_end_pos: [B * N, 1]
mask
=
positions
<=
index_end_pos
# mask: [B * N, L]
logits
=
logits
.
masked_fill
(
~
mask
,
float
(
"-inf"
))
topk_indices
=
logits
.
topk
(
topk_tokens
,
dim
=-
1
)[
1
].
to
(
torch
.
int32
)
# [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices
[
topk_indices
>
index_end_pos
]
=
-
1
raw_topk_indices
[:
topk_indices
.
shape
[
0
],
:
topk_indices
.
shape
[
1
]]
=
(
topk_indices
)
@
staticmethod
def
register_ops_once
()
->
None
:
global
_OPS_REGISTERED
if
not
_OPS_REGISTERED
:
# register all the custom ops here
direct_register_custom_op
(
op_name
=
"xpu_ops_deepseek_scaling_rope"
,
op_func
=
_xpu_ops_deepseek_scaling_rope_impl
,
mutates_args
=
[],
fake_impl
=
_xpu_ops_deepseek_scaling_rope_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
_OPS_REGISTERED
=
True
xpu_ops
.
register_ops_once
()
\ No newline at end of file
vllm/config/compilation.py
View file @
eefa41c1
...
@@ -337,9 +337,10 @@ class DynamicShapesConfig:
...
@@ -337,9 +337,10 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
"""
assume_32_bit_indexing
:
bool
=
Tru
e
assume_32_bit_indexing
:
bool
=
Fals
e
"""
"""
whether all tensor sizes can use 32 bit indexing.
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
"""
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
...
...
vllm/config/speculative.py
View file @
eefa41c1
...
@@ -259,6 +259,16 @@ class SpeculativeConfig:
...
@@ -259,6 +259,16 @@ class SpeculativeConfig:
}
}
)
)
if
hf_config
.
architectures
[
0
]
==
"GlmOcrForConditionalGeneration"
:
hf_config
.
model_type
=
"glm_ocr_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
(
{
"num_hidden_layers"
:
0
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"GlmOcrMTPModel"
],
}
)
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
if
hf_config
.
model_type
==
"ernie4_5_moe"
:
hf_config
.
model_type
=
"ernie_mtp"
hf_config
.
model_type
=
"ernie_mtp"
if
hf_config
.
model_type
==
"ernie_mtp"
:
if
hf_config
.
model_type
==
"ernie_mtp"
:
...
...
vllm/distributed/device_communicators/all2all.py
View file @
eefa41c1
...
@@ -72,7 +72,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -72,7 +72,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return
buffer
return
buffer
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -96,6 +96,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -96,6 +96,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
)
)
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
if
extra_tensors
is
not
None
:
raise
NotImplementedError
(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
cu_tokens_across_sp_cpu
=
dp_metadata
.
cu_tokens_across_sp
(
sp_size
)
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_weights
=
self
.
naive_multicast
(
topk_weights
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
topk_ids
=
self
.
naive_multicast
(
topk_ids
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
...
@@ -127,7 +155,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -127,7 +155,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def
__init__
(
self
,
cpu_group
,
tcp_store_group
=
None
):
def
__init__
(
self
,
cpu_group
,
tcp_store_group
=
None
):
super
().
__init__
(
cpu_group
,
tcp_store_group
)
super
().
__init__
(
cpu_group
,
tcp_store_group
)
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -161,6 +189,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -161,6 +189,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
sizes
=
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
assert
sizes
is
not
None
dist_group
=
get_ep_group
()
if
is_sequence_parallel
else
get_dp_group
()
assert
sizes
[
dist_group
.
rank_in_group
]
==
hidden_states
.
shape
[
0
]
tensors_to_gather
=
[
hidden_states
,
topk_weights
,
topk_ids
]
if
extra_tensors
is
not
None
:
tensors_to_gather
.
extend
(
extra_tensors
)
gathered_tensors
=
dist_group
.
all_gatherv
(
tensors_to_gather
,
dim
=
0
,
sizes
=
sizes
,
)
hidden_states
=
gathered_tensors
[
0
]
topk_weights
=
gathered_tensors
[
1
]
topk_ids
=
gathered_tensors
[
2
]
if
extra_tensors
is
None
:
return
hidden_states
,
topk_weights
,
topk_ids
return
hidden_states
,
topk_weights
,
topk_ids
,
gathered_tensors
[
3
:]
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -200,7 +268,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -200,7 +268,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def
get_handle
(
self
,
kwargs
):
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -209,6 +277,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -209,6 +277,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
raise
NotImplementedError
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
eefa41c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
threading
import
threading
from
typing
import
Any
from
weakref
import
WeakValueDictionary
from
weakref
import
WeakValueDictionary
import
torch
import
torch
...
@@ -70,13 +69,32 @@ class All2AllManagerBase:
...
@@ -70,13 +69,32 @@ class All2AllManagerBase:
# and reuse it for the same config.
# and reuse it for the same config.
raise
NotImplementedError
raise
NotImplementedError
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
Any
:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise
NotImplementedError
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
# Subclasses should either:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
# - raise a clear error if extra_tensors is not supported.
...
@@ -312,7 +330,7 @@ class DeviceCommunicatorBase:
...
@@ -312,7 +330,7 @@ class DeviceCommunicatorBase:
for
module
in
moe_modules
:
for
module
in
moe_modules
:
module
.
maybe_init_modular_kernel
()
module
.
maybe_init_modular_kernel
()
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -326,8 +344,29 @@ class DeviceCommunicatorBase:
...
@@ -326,8 +344,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
This is a no-op in the base class.
"""
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
router_logits
,
extra_tensors
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if
extra_tensors
is
not
None
:
return
hidden_states
,
topk_weights
,
topk_ids
,
extra_tensors
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -338,4 +377,4 @@ class DeviceCommunicatorBase:
...
@@ -338,4 +377,4 @@ class DeviceCommunicatorBase:
return
hidden_states
return
hidden_states
def
batch_isend_irecv
(
self
,
p2p_ops
:
list
):
def
batch_isend_irecv
(
self
,
p2p_ops
:
list
):
raise
NotImplementedError
raise
NotImplementedError
\ No newline at end of file
vllm/distributed/device_communicators/cpu_communicator.py
View file @
eefa41c1
...
@@ -151,29 +151,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
...
@@ -151,29 +151,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
(
# type: ignore[override]
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
_router_logits
(
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
hidden_states
,
is_sequence_parallel
,
)
)
return
hidden_states
class
_CPUSHMDistributed
:
class
_CPUSHMDistributed
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
eefa41c1
...
@@ -396,7 +396,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -396,7 +396,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
return
output_list
def
dispatch
(
# type: ignore[override]
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -406,20 +406,54 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -406,20 +406,54 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
_router_logits
(
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
hidden_states
,
is_sequence_parallel
,
)
)
def
batch_isend_irecv
(
self
,
p2p_ops
:
list
):
def
batch_isend_irecv
(
self
,
p2p_ops
:
list
):
...
@@ -427,4 +461,4 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -427,4 +461,4 @@ class CudaCommunicator(DeviceCommunicatorBase):
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
pynccl_comm
.
batch_isend_irecv
(
p2p_ops
)
pynccl_comm
.
batch_isend_irecv
(
p2p_ops
)
else
:
else
:
raise
ValueError
(
"No PyNCCL communicator found"
)
raise
ValueError
(
"No PyNCCL communicator found"
)
\ No newline at end of file
vllm/distributed/device_communicators/mnnvl_compat.py
View file @
eefa41c1
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
...
...
vllm/distributed/device_communicators/xpu_communicator.py
View file @
eefa41c1
...
@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
...
@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
)
->
None
:
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
dist
.
broadcast
(
input_
,
src
=
src
,
group
=
self
.
device_group
)
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
_router_logits
(
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
is_sequence_parallel
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
=
extra_tensors
,
)
)
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
hidden_states
=
self
.
all2all_manager
.
combine
(
return
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
hidden_states
,
)
is_sequence_parallel
,
return
hidden_states
)
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
eefa41c1
...
@@ -384,7 +384,9 @@ class TpKVTopology:
...
@@ -384,7 +384,9 @@ class TpKVTopology:
@
property
@
property
def
split_k_and_v
(
self
)
->
bool
:
def
split_k_and_v
(
self
)
->
bool
:
# Whether to register regions for K and V separately (when present).
# Whether to register regions for K and V separately (when present).
return
not
(
self
.
is_mla
or
self
.
is_kv_layout_blocks_first
)
return
not
(
self
.
_cross_layers_blocks
or
self
.
is_mla
or
self
.
is_kv_layout_blocks_first
)
@
property
@
property
def
tp_size
(
self
)
->
int
:
def
tp_size
(
self
)
->
int
:
...
@@ -554,4 +556,4 @@ def get_current_attn_backend(
...
@@ -554,4 +556,4 @@ def get_current_attn_backend(
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
vllm_config
:
VllmConfig
,
layer_names
:
list
[
str
]
|
None
=
None
)
->
type
[
AttentionBackend
]:
)
->
type
[
AttentionBackend
]:
"""Get the first attention backend for the given layers."""
"""Get the first attention backend for the given layers."""
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
return
get_current_attn_backends
(
vllm_config
,
layer_names
)[
0
]
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
eefa41c1
...
@@ -56,7 +56,7 @@ from vllm.logger import init_logger
...
@@ -56,7 +56,7 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.utils.network_utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadata
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
...
@@ -186,7 +186,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
...
@@ -186,7 +186,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def
compute_nixl_compatibility_hash
(
def
compute_nixl_compatibility_hash
(
vllm_config
:
VllmConfig
,
attn_backend_name
:
str
vllm_config
:
VllmConfig
,
attn_backend_name
:
str
,
cross_layers_blocks
:
bool
)
->
str
:
)
->
str
:
"""
"""
Compute compatibility hash for NIXL KV transfer.
Compute compatibility hash for NIXL KV transfer.
...
@@ -1164,12 +1164,9 @@ class NixlConnectorWorker:
...
@@ -1164,12 +1164,9 @@ class NixlConnectorWorker:
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
info
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
info
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
logger
.
info
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
self
.
compat_hash
=
compute_nixl_compatibility_hash
(
# lazy initialized in register_kv_caches
self
.
vllm_config
,
self
.
backend_name
self
.
compat_hash
:
str
|
None
=
None
)
self
.
kv_topo
:
TpKVTopology
|
None
=
None
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
)
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
...
@@ -1184,7 +1181,6 @@ class NixlConnectorWorker:
...
@@ -1184,7 +1181,6 @@ class NixlConnectorWorker:
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
"enforce_handshake_compat"
,
True
)
)
self
.
_physical_blocks_per_logical_kv_block
=
1
def
_sync_block_size_with_kernel
(
self
)
->
None
:
def
_sync_block_size_with_kernel
(
self
)
->
None
:
backends
=
get_current_attn_backends
(
self
.
vllm_config
)
backends
=
get_current_attn_backends
(
self
.
vllm_config
)
...
@@ -1232,6 +1228,7 @@ class NixlConnectorWorker:
...
@@ -1232,6 +1228,7 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
# this happens to be the same single rank_i.
assert
self
.
kv_topo
is
not
None
p_remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks
(
remote_tp_size
)
p_remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks
(
remote_tp_size
)
remote_rank_to_agent_name
=
{}
remote_rank_to_agent_name
=
{}
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
...
@@ -1269,6 +1266,7 @@ class NixlConnectorWorker:
...
@@ -1269,6 +1266,7 @@ class NixlConnectorWorker:
)
)
# Check compatibility hash BEFORE decoding agent metadata
# Check compatibility hash BEFORE decoding agent metadata
assert
self
.
compat_hash
is
not
None
if
(
if
(
self
.
enforce_compat_hash
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
...
@@ -1547,7 +1545,6 @@ class NixlConnectorWorker:
...
@@ -1547,7 +1545,6 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
tensor_size_bytes
=
None
tensor_size_bytes
=
None
# Enable different block lengths for different layers *only* when MLA is used.
# Enable different block lengths for different layers *only* when MLA is used.
...
@@ -1698,6 +1695,7 @@ class NixlConnectorWorker:
...
@@ -1698,6 +1695,7 @@ class NixlConnectorWorker:
ssm_sizes
=
self
.
_mamba_ssm_size
,
ssm_sizes
=
self
.
_mamba_ssm_size
,
)
)
# Wrap metadata in payload with hash for defensive decoding
# Wrap metadata in payload with hash for defensive decoding
assert
self
.
compat_hash
is
not
None
encoder
=
msgspec
.
msgpack
.
Encoder
()
encoder
=
msgspec
.
msgpack
.
Encoder
()
self
.
xfer_handshake_metadata
=
NixlHandshakePayload
(
self
.
xfer_handshake_metadata
=
NixlHandshakePayload
(
compatibility_hash
=
self
.
compat_hash
,
compatibility_hash
=
self
.
compat_hash
,
...
@@ -2177,6 +2175,7 @@ class NixlConnectorWorker:
...
@@ -2177,6 +2175,7 @@ class NixlConnectorWorker:
if
len
(
self
.
device_kv_caches
)
==
0
:
if
len
(
self
.
device_kv_caches
)
==
0
:
return
return
assert
block_size_ratio
>=
1
,
"Only nP < nD supported currently."
assert
block_size_ratio
>=
1
,
"Only nP < nD supported currently."
assert
self
.
kv_topo
is
not
None
if
self
.
enable_permute_local_kv
and
block_size_ratio
>
1
:
if
self
.
enable_permute_local_kv
and
block_size_ratio
>
1
:
logger
.
debug
(
logger
.
debug
(
"Post-processing device kv cache on receive by converting "
"Post-processing device kv cache on receive by converting "
...
@@ -2196,7 +2195,7 @@ class NixlConnectorWorker:
...
@@ -2196,7 +2195,7 @@ class NixlConnectorWorker:
block_size_ratio
,
block_size_ratio
,
)
)
split_k_and_v
=
not
(
self
.
use_mla
or
self
.
kv_topo
.
is_kv_layout_blocks_first
)
split_k_and_v
=
self
.
kv_topo
.
split_k_and_v
for
block_ids
in
block_ids_list
:
for
block_ids
in
block_ids_list
:
indices
=
torch
.
tensor
(
block_ids
,
device
=
self
.
device_type
,
dtype
=
torch
.
long
)
indices
=
torch
.
tensor
(
block_ids
,
device
=
self
.
device_type
,
dtype
=
torch
.
long
)
...
@@ -2221,6 +2220,7 @@ class NixlConnectorWorker:
...
@@ -2221,6 +2220,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
to track which workers are done.
"""
"""
assert
self
.
kv_topo
is
not
None
done_sending
=
self
.
_get_new_notifs
()
done_sending
=
self
.
_get_new_notifs
()
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
done_recving
=
self
.
_pop_done_transfers
(
self
.
_recving_transfers
)
...
@@ -2291,6 +2291,7 @@ class NixlConnectorWorker:
...
@@ -2291,6 +2291,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
for all consumers to be done pulling.
"""
"""
assert
self
.
kv_topo
is
not
None
notified_req_ids
:
set
[
str
]
=
set
()
notified_req_ids
:
set
[
str
]
=
set
()
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notif
in
notifs
:
for
notif
in
notifs
:
...
@@ -2451,7 +2452,7 @@ class NixlConnectorWorker:
...
@@ -2451,7 +2452,7 @@ class NixlConnectorWorker:
self
.
_reqs_to_send
[
req_id
]
=
expiration_time
self
.
_reqs_to_send
[
req_id
]
=
expiration_time
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
assert
meta
.
remote
is
not
None
and
self
.
kv_topo
is
not
None
remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks_from_engine_id
(
remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks_from_engine_id
(
meta
.
remote
.
engine_id
meta
.
remote
.
engine_id
)
)
...
@@ -2782,6 +2783,7 @@ class NixlConnectorWorker:
...
@@ -2782,6 +2783,7 @@ class NixlConnectorWorker:
+-------------------+ +--------------------+
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
|1st_split-2nd_split| |1st_split-2nd_split |
"""
"""
assert
self
.
kv_topo
is
not
None
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# For indexing only half (either just the K or V part).
# For indexing only half (either just the K or V part).
if
mamba_view
:
if
mamba_view
:
...
@@ -3103,4 +3105,4 @@ class NixlPromMetrics(KVConnectorPromMetrics):
...
@@ -3103,4 +3105,4 @@ class NixlPromMetrics(KVConnectorPromMetrics):
[
"num_failed_transfers"
,
"num_failed_notifications"
,
"num_kv_expired_reqs"
],
[
"num_failed_transfers"
,
"num_failed_notifications"
,
"num_kv_expired_reqs"
],
):
):
for
list_item
in
transfer_stats_data
[
counter_item_key
]:
for
list_item
in
transfer_stats_data
[
counter_item_key
]:
counter_obj
[
engine_idx
].
inc
(
list_item
)
counter_obj
[
engine_idx
].
inc
(
list_item
)
\ No newline at end of file
vllm/distributed/parallel_state.py
View file @
eefa41c1
...
@@ -1065,7 +1065,7 @@ class GroupCoordinator:
...
@@ -1065,7 +1065,7 @@ class GroupCoordinator:
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
def
dispatch
(
def
dispatch
_router_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -1076,7 +1076,7 @@ class GroupCoordinator:
...
@@ -1076,7 +1076,7 @@ class GroupCoordinator:
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
):
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
# type: ignore[call-arg]
return
self
.
device_communicator
.
dispatch
_router_logits
(
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
is_sequence_parallel
,
is_sequence_parallel
,
...
@@ -1085,6 +1085,28 @@ class GroupCoordinator:
...
@@ -1085,6 +1085,28 @@ class GroupCoordinator:
else
:
else
:
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
):
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
hidden_states
,
topk_weights
,
topk_ids
,
is_sequence_parallel
,
extra_tensors
,
)
else
:
return
hidden_states
,
topk_weights
,
topk_ids
def
combine
(
def
combine
(
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
,
is_sequence_parallel
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -2090,4 +2112,4 @@ def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
...
@@ -2090,4 +2112,4 @@ def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
if
is_same_node
and
node_assignment
[
other_rank
]
==
0
:
if
is_same_node
and
node_assignment
[
other_rank
]
==
0
:
node_assignment
[
other_rank
]
=
next_node_id
node_assignment
[
other_rank
]
=
next_node_id
return
next_node_id
return
next_node_id
\ No newline at end of file
Prev
1
2
3
4
5
6
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment