Commit eefa41c1 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.18.0

parent 82155c76
...@@ -266,7 +266,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -266,7 +266,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
), ),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
"ExaoneMoEForCausalLM": _HfExamplesInfo( "ExaoneMoEForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.0.0" "LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.1.0"
), ),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
...@@ -283,11 +283,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -283,11 +283,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"), "Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"),
"Glm4MoeLiteForCausalLM": _HfExamplesInfo( "Glm4MoeLiteForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash", "zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev", min_transformers_version="5.0.0",
is_available_online=False,
),
"GlmMoeDsaForCausalLM": _HfExamplesInfo(
"zai-org/GLM-5", min_transformers_version="5.0.1", is_available_online=False
), ),
"GlmMoeDsaForCausalLM": _HfExamplesInfo( "GlmMoeDsaForCausalLM": _HfExamplesInfo(
"zai-org/GLM-5", min_transformers_version="5.0.1", is_available_online=False "zai-org/GLM-5", min_transformers_version="5.0.1", is_available_online=False
...@@ -743,7 +739,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -743,7 +739,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only] # [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo( "AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev" "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
), ),
"MusicFlamingoForConditionalGeneration": _HfExamplesInfo( "MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev" "nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev"
...@@ -1237,7 +1233,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1237,7 +1233,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel": _HfExamplesInfo( "Glm4MoeLiteMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash", "zai-org/GLM-4.7-Flash",
speculative_model="zai-org/GLM-4.7-Flash", speculative_model="zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0",
),
"GlmOcrMTPModel": _HfExamplesInfo(
"zai-org/GLM-OCR",
speculative_model="zai-org/GLM-OCR",
is_available_online=False, is_available_online=False,
min_transformers_version="5.1.0",
), ),
"LongCatFlashMTPModel": _HfExamplesInfo( "LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat", "meituan-longcat/LongCat-Flash-Chat",
...@@ -1282,27 +1284,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1282,27 +1284,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS = { _TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": _HfExamplesInfo( "TransformersEmbeddingModel": _HfExamplesInfo(
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0.dev" "BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
), ),
"TransformersForSequenceClassification": _HfExamplesInfo( "TransformersForSequenceClassification": _HfExamplesInfo(
"papluca/xlm-roberta-base-language-detection", "papluca/xlm-roberta-base-language-detection",
min_transformers_version="5.0.0.dev", min_transformers_version="5.0.0",
), ),
"TransformersForCausalLM": _HfExamplesInfo( "TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True "hmellor/Ilama-3.2-1B", trust_remote_code=True
), ),
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo( "TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0.dev" "allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
), ),
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo( "TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0.dev" "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
), ),
"TransformersMoEEmbeddingModel": _HfExamplesInfo( "TransformersMoEEmbeddingModel": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0.dev" "Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
), ),
"TransformersMoEForSequenceClassification": _HfExamplesInfo( "TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0.dev" "Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
), ),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"), "TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo( "TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
......
...@@ -76,7 +76,7 @@ def test_models( ...@@ -76,7 +76,7 @@ def test_models(
from packaging.version import Version from packaging.version import Version
installed = Version(transformers.__version__) installed = Version(transformers.__version__)
required = Version("5.0.0.dev") required = Version("5.0.0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required: if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip( pytest.skip(
"MoE models with the Transformers modeling backend require " "MoE models with the Transformers modeling backend require "
...@@ -237,4 +237,4 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch): ...@@ -237,4 +237,4 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
embeddings_1_lst=vllm_outputs, embeddings_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
) )
\ No newline at end of file
...@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module): ...@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -59,4 +59,4 @@ class MyGemma2Embedding(nn.Module): ...@@ -59,4 +59,4 @@ class MyGemma2Embedding(nn.Module):
weights = ( weights = (
(name, data) for name, data in weights if not name.startswith("lm_head.") (name, data) for name, data in weights if not name.startswith("lm_head.")
) )
return self.model.load_weights(weights) return self.model.load_weights(weights)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the UvicornAccessLogFilter class.
"""
import logging
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
class TestUvicornAccessLogFilter:
    """Test cases for UvicornAccessLogFilter."""

    @staticmethod
    def _make_access_record(
        path: str, method: str = "GET", status_code: int = 200
    ) -> logging.LogRecord:
        """Build a ``uvicorn.access`` INFO record for one request.

        The record uses the ``'%s - "%s %s HTTP/%s" %d'`` message format with
        the args tuple ``(client, method, path, http_version, status_code)``;
        client address and HTTP version are fixed because no test varies them.
        """
        return logging.LogRecord(
            name="uvicorn.access",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg='%s - "%s %s HTTP/%s" %d',
            args=("127.0.0.1:12345", method, path, "1.1", status_code),
            exc_info=None,
        )

    def test_filter_allows_all_when_no_excluded_paths(self):
        """Filter should allow all logs when no paths are excluded."""
        # Named log_filter (not filter) to avoid shadowing the builtin.
        log_filter = UvicornAccessLogFilter(excluded_paths=[])
        record = self._make_access_record("/v1/completions")
        assert log_filter.filter(record) is True

    def test_filter_allows_all_when_excluded_paths_is_none(self):
        """Filter should allow all logs when excluded_paths is None."""
        log_filter = UvicornAccessLogFilter(excluded_paths=None)
        record = self._make_access_record("/health")
        assert log_filter.filter(record) is True

    def test_filter_excludes_health_endpoint(self):
        """Filter should exclude /health endpoint when configured."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])
        record = self._make_access_record("/health")
        assert log_filter.filter(record) is False

    def test_filter_excludes_metrics_endpoint(self):
        """Filter should exclude /metrics endpoint when configured."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/metrics"])
        record = self._make_access_record("/metrics")
        assert log_filter.filter(record) is False

    def test_filter_allows_non_excluded_endpoints(self):
        """Filter should allow endpoints not in the excluded list."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health", "/metrics"])
        record = self._make_access_record("/v1/completions", method="POST")
        assert log_filter.filter(record) is True

    def test_filter_excludes_multiple_endpoints(self):
        """Filter should exclude multiple configured endpoints."""
        log_filter = UvicornAccessLogFilter(
            excluded_paths=["/health", "/metrics", "/ping"]
        )

        # Test /health
        record_health = self._make_access_record("/health")
        assert log_filter.filter(record_health) is False

        # Test /metrics
        record_metrics = self._make_access_record("/metrics")
        assert log_filter.filter(record_metrics) is False

        # Test /ping
        record_ping = self._make_access_record("/ping")
        assert log_filter.filter(record_ping) is False

    def test_filter_with_query_parameters(self):
        """Filter should exclude endpoints even with query parameters."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])
        record = self._make_access_record("/health?verbose=true")
        assert log_filter.filter(record) is False

    def test_filter_different_http_methods(self):
        """Filter should exclude endpoints regardless of HTTP method."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/ping"])

        # Test GET
        record_get = self._make_access_record("/ping", method="GET")
        assert log_filter.filter(record_get) is False

        # Test POST
        record_post = self._make_access_record("/ping", method="POST")
        assert log_filter.filter(record_post) is False

    def test_filter_with_different_status_codes(self):
        """Filter should exclude endpoints regardless of status code."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])
        for status_code in [200, 500, 503]:
            record = self._make_access_record("/health", status_code=status_code)
            assert log_filter.filter(record) is False
class TestCreateUvicornLogConfig:
    """Checks on the dictionary returned by create_uvicorn_log_config."""

    def test_creates_valid_config_structure(self):
        """The returned config exposes every top-level logging key."""
        log_config = create_uvicorn_log_config(excluded_paths=["/health"])

        assert "version" in log_config
        assert log_config["version"] == 1
        for required_key in (
            "disable_existing_loggers",
            "formatters",
            "handlers",
            "loggers",
            "filters",
        ):
            assert required_key in log_config

    def test_config_includes_access_log_filter(self):
        """The filters section declares the access log filter and its paths."""
        log_config = create_uvicorn_log_config(excluded_paths=["/health", "/metrics"])
        filters_section = log_config["filters"]

        assert "access_log_filter" in filters_section
        filter_entry = filters_section["access_log_filter"]
        assert filter_entry["()"] == UvicornAccessLogFilter
        assert filter_entry["excluded_paths"] == ["/health", "/metrics"]

    def test_config_applies_filter_to_access_handler(self):
        """The access handler lists the access log filter among its filters."""
        log_config = create_uvicorn_log_config(excluded_paths=["/health"])
        handlers_section = log_config["handlers"]

        assert "access" in handlers_section
        access_handler = handlers_section["access"]
        assert "filters" in access_handler
        assert "access_log_filter" in access_handler["filters"]

    def test_config_with_custom_log_level(self):
        """A lowercase log_level shows up upper-cased on all uvicorn loggers."""
        log_config = create_uvicorn_log_config(
            excluded_paths=["/health"], log_level="debug"
        )
        loggers_section = log_config["loggers"]

        for logger_name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
            assert loggers_section[logger_name]["level"] == "DEBUG"

    def test_config_with_empty_excluded_paths(self):
        """An empty excluded_paths list is carried through unchanged."""
        log_config = create_uvicorn_log_config(excluded_paths=[])
        filter_entry = log_config["filters"]["access_log_filter"]
        assert filter_entry["excluded_paths"] == []

    def test_config_with_none_excluded_paths(self):
        """excluded_paths=None ends up as an empty list in the config."""
        log_config = create_uvicorn_log_config(excluded_paths=None)
        filter_entry = log_config["filters"]["access_log_filter"]
        assert filter_entry["excluded_paths"] == []
class TestIntegration:
    """Integration tests for the access log filter."""

    def test_filter_with_real_logger(self):
        """Test filter works with a real Python logger simulating uvicorn."""
        # Create a logger with our filter (simulating uvicorn.access)
        logger = logging.getLogger("uvicorn.access")

        # logging.getLogger returns a process-wide singleton, so remember its
        # state and put it back afterwards; otherwise this test leaks its
        # tracking handler and level into every test that runs after it.
        saved_level = logger.level
        saved_handlers = list(logger.handlers)

        # Create a custom handler that tracks messages
        logged_messages: list[str] = []

        class TrackingHandler(logging.Handler):
            def emit(self, record):
                logged_messages.append(record.getMessage())

        handler = TrackingHandler()
        handler.setLevel(logging.INFO)
        # Named log_filter (not filter) to avoid shadowing the builtin.
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health", "/metrics"])
        handler.addFilter(log_filter)

        try:
            logger.setLevel(logging.INFO)
            # Clear any existing handlers
            logger.handlers = []
            logger.addHandler(handler)

            # Log using uvicorn's format with args tuple
            # Format: '%s - "%s %s HTTP/%s" %d'
            requests = (
                ("GET", "/health"),
                ("GET", "/v1/completions"),
                ("GET", "/metrics"),
                ("POST", "/v1/chat/completions"),
            )
            for method, path in requests:
                logger.info(
                    '%s - "%s %s HTTP/%s" %d',
                    "127.0.0.1:12345",
                    method,
                    path,
                    "1.1",
                    200,
                )
        finally:
            logger.handlers = saved_handlers
            logger.setLevel(saved_level)

        # Verify only non-excluded endpoints were logged
        assert len(logged_messages) == 2
        assert "/v1/completions" in logged_messages[0]
        assert "/v1/chat/completions" in logged_messages[1]

    def test_filter_allows_non_uvicorn_access_logs(self):
        """Test filter allows logs from non-uvicorn.access loggers."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])

        # Log record from a different logger name
        record = logging.LogRecord(
            name="uvicorn.error",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Some error message about /health",
            args=(),
            exc_info=None,
        )

        # Should allow because it's not from uvicorn.access
        assert log_filter.filter(record) is True

    def test_filter_handles_malformed_args(self):
        """Test filter handles log records with unexpected args format."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])

        # Log record with insufficient args
        record = logging.LogRecord(
            name="uvicorn.access",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Some message",
            args=("only", "two"),
            exc_info=None,
        )

        # Should allow because args doesn't have expected format
        assert log_filter.filter(record) is True

    def test_filter_handles_non_tuple_args(self):
        """Test filter handles log records with non-tuple args."""
        log_filter = UvicornAccessLogFilter(excluded_paths=["/health"])

        # Log record with None args
        record = logging.LogRecord(
            name="uvicorn.access",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg="Some message without args",
            args=None,
            exc_info=None,
        )

        # Should allow because args is None
        assert log_filter.filter(record) is True
\ No newline at end of file
...@@ -383,7 +383,7 @@ def _run_eagle_correctness( ...@@ -383,7 +383,7 @@ def _run_eagle_correctness(
from packaging.version import Version from packaging.version import Version
installed = Version(transformers.__version__) installed = Version(transformers.__version__)
required = Version("5.0.0.dev") required = Version("5.0.0")
if installed < required: if installed < required:
pytest.skip( pytest.skip(
"Eagle3 with the Transformers modeling backend requires " "Eagle3 with the Transformers modeling backend requires "
...@@ -1030,4 +1030,4 @@ def compute_acceptance_len(metrics: list[Metric]) -> float: ...@@ -1030,4 +1030,4 @@ def compute_acceptance_len(metrics: list[Metric]) -> float:
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
if n_drafts == 0: if n_drafts == 0:
return 1 return 1
return 1 + (n_accepted_toks / n_drafts) return 1 + (n_accepted_toks / n_drafts)
\ No newline at end of file
...@@ -59,9 +59,9 @@ fi ...@@ -59,9 +59,9 @@ fi
# Build the kv-transfer-config once # Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}' KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
else else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}" KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
fi fi
# Models to run # Models to run
...@@ -295,4 +295,4 @@ for model in "${MODELS[@]}"; do ...@@ -295,4 +295,4 @@ for model in "${MODELS[@]}"; do
run_tests_for_model "$model" run_tests_for_model "$model"
done done
echo "All tests completed!" echo "All tests completed!"
\ No newline at end of file
...@@ -18,8 +18,12 @@ import ray ...@@ -18,8 +18,12 @@ import ray
import torch import torch
from vllm import LLM from vllm import LLM
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
...@@ -58,6 +62,8 @@ from vllm.v1.kv_cache_interface import ( ...@@ -58,6 +62,8 @@ from vllm.v1.kv_cache_interface import (
) )
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import ( from .utils import (
create_request, create_request,
...@@ -1498,44 +1504,6 @@ def test_register_kv_caches( ...@@ -1498,44 +1504,6 @@ def test_register_kv_caches(
backend_cls = TritonAttentionBackend backend_cls = TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with ( with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
...@@ -1716,14 +1684,13 @@ def test_register_kv_caches( ...@@ -1716,14 +1684,13 @@ def test_register_kv_caches(
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size # Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, ( assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
) )
num_blocks = 2 if connector.prefer_cross_layer_blocks:
if is_blocks_first: num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks // 2 expected_block_len = expected_tensor_size // num_blocks
else: else:
num_blocks = kv_cache_config.num_blocks num_blocks = kv_cache_config.num_blocks
if is_blocks_first: if is_blocks_first:
...@@ -2360,7 +2327,9 @@ def test_compatibility_hash_validation( ...@@ -2360,7 +2327,9 @@ def test_compatibility_hash_validation(
) )
) )
remote_hash = compute_nixl_compatibility_hash( remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, decode_worker.backend_name remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
) )
prefill_block_size = config_overrides.get("block_size", 16) prefill_block_size = config_overrides.get("block_size", 16)
...@@ -2497,4 +2466,4 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2497,4 +2466,4 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
port=1234, port=1234,
remote_tp_size=1, remote_tp_size=1,
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
) )
\ No newline at end of file
...@@ -3044,13 +3044,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): ...@@ -3044,13 +3044,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class CPUDNNLGEMMHandler: class CPUDNNLGEMMHandler:
def __init__(self) -> None: def __init__(self) -> None:
self.handler: int | None = None self.handler_tensor: torch.Tensor | None = None
self.n = -1 self.n = -1
self.k = -1 self.k = -1
def __del__(self): def __del__(self):
if self.handler is not None: if self.handler_tensor is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler) torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item())
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler")) _supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
...@@ -3066,8 +3066,10 @@ def create_onednn_mm( ...@@ -3066,8 +3066,10 @@ def create_onednn_mm(
) -> CPUDNNLGEMMHandler: ) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler() handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size() handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler( # store the handler pointer in a tensor so it doesn't get inlined
weight, primitive_cache_size handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size),
dtype=torch.int64,
) )
return handler return handler
...@@ -3079,7 +3081,7 @@ def onednn_mm( ...@@ -3079,7 +3081,7 @@ def onednn_mm(
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm( torch.ops._C.onednn_mm(
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor
) )
return output return output
...@@ -3095,8 +3097,17 @@ def create_onednn_scaled_mm( ...@@ -3095,8 +3097,17 @@ def create_onednn_scaled_mm(
) -> CPUDNNLGEMMHandler: ) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler() handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size() handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( # store the handler pointer in a tensor so it doesn't get inlined
weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_scaled_mm_handler(
weight,
weight_scales,
output_type,
dynamic_quant,
use_azp,
primitive_cache_size,
),
dtype=torch.int64,
) )
return handler return handler
...@@ -3149,11 +3160,15 @@ def onednn_scaled_mm( ...@@ -3149,11 +3160,15 @@ def onednn_scaled_mm(
bias: torch.Tensor | None, bias: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm( torch.ops._C.onednn_scaled_mm(
output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler output,
x,
input_scale,
input_zp,
input_zp_adj,
bias,
dnnl_handler.handler_tensor,
) )
return output
def cpu_attn_get_scheduler_metadata( def cpu_attn_get_scheduler_metadata(
num_reqs: int, num_reqs: int,
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"): ...@@ -53,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device) return torch.empty((M, N), dtype=input.dtype, device=input.device)
def _xpu_ops_deepseek_scaling_rope_impl(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
assert key is not None
return torch.ops._xpu_C.deepseek_scaling_rope(
positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style
)
def _xpu_ops_deepseek_scaling_rope_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return query, key
# Global flag to ensure ops are registered only once.
# NOTE(review): the code that reads and sets this flag is not visible in this
# part of the file — confirm against the registration routine.
_OPS_REGISTERED = False
class xpu_ops: class xpu_ops:
@staticmethod @staticmethod
def flash_attn_varlen_func( def flash_attn_varlen_func(
...@@ -105,9 +138,10 @@ class xpu_ops: ...@@ -105,9 +138,10 @@ class xpu_ops:
assert len(window_size) == 2 assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1]) # noqa: F841 real_window_size = (window_size[0], window_size[1]) # noqa: F841
# In encode attention, v maybe not contiguous and current # In encode attention, k and v maybe not contiguous and current
# kernel can't handle it # kernel can't handle it
if block_table is None: if block_table is None:
k = k.contiguous()
v = v.contiguous() v = v.contiguous()
return flash_attn_varlen_func( return flash_attn_varlen_func(
out=out, out=out,
...@@ -156,3 +190,265 @@ class xpu_ops: ...@@ -156,3 +190,265 @@ class xpu_ops:
"get_scheduler_metadata is not implemented for xpu_ops, returning None." "get_scheduler_metadata is not implemented for xpu_ops, returning None."
) )
return None return None
@staticmethod
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
) -> None:
    """Group-quantize ``k`` to FP8 and scatter it, with its scales, into ``kv_cache``.

    ``k`` is viewed as ``[total_tokens, head_dim]`` and quantized in groups of
    ``quant_block_size`` along the last dimension, using one float32 scale per
    group. For each token the FP8 bytes and the raw bytes of its scale are
    concatenated and written into ``kv_cache`` (viewed as rows of
    ``kv_cache.shape[-1]`` bytes) at the row given by ``slot_mapping``.

    Args:
        k: Keys to quantize; the last dimension is ``head_dim`` and must have
            stride 1.
        kv_cache: Destination cache, updated in place. Per the layout comment
            below it is ``[num_block, block_size, head_dim + 4]``.
        slot_mapping: Destination row index for each token; flattened before
            use.
        quant_block_size: Number of elements that share one scale. Must divide
            ``head_dim``.
        scale_fmt: When equal to ``"ue8m0"`` the scales are rounded up to the
            next power of two; any other value leaves them unrounded.

    NOTE(review): the packed row is ``head_dim + 4`` bytes wide, i.e. it holds
    a single float32 scale per token. With more than one group per token the
    concatenation below fails on mismatched row counts — confirm callers
    always pass ``quant_block_size == head_dim``.
    """
    head_dim = k.shape[-1]
    k = k.view(-1, head_dim)  # [total_tokens, head_dim]

    # Validate inputs
    assert head_dim % quant_block_size == 0, (
        f"Last dimension {head_dim} must be divisible by "
        f"group_size {quant_block_size}"
    )
    assert k.stride(-1) == 1, "Input tensor groups must be contiguous"

    fp8_dtype = current_platform.fp8_dtype()
    fp8_info = torch.finfo(fp8_dtype)

    # Separate the groups: [total_tokens, num_groups, quant_block_size]
    num_groups = head_dim // quant_block_size
    k_grouped = k.view(k.shape[0], num_groups, quant_block_size)

    # Per-group absolute maximum, floored at a tiny epsilon so an all-zero
    # group does not produce a zero scale. The floor is applied in float32:
    # 1e-10 is below the smallest float16 subnormal, so building it in the
    # input dtype would turn it into 0 and yield NaN (0 / 0) for such groups.
    abs_max = k_grouped.abs().amax(dim=-1).float().clamp(min=1e-10)

    # Compute scales, shape [total_tokens, num_groups], float32
    scales = abs_max / fp8_info.max
    if scale_fmt == "ue8m0":
        # For UE8M0 format, scales must be powers of 2
        scales = torch.pow(2.0, torch.ceil(torch.log2(scales)))

    # Quantize: divide each group by its scale, saturate to the FP8 range,
    # then cast. The division runs in float32 to match the scale precision.
    k_scaled = k_grouped.float() / scales.unsqueeze(-1)
    k_fp8 = k_scaled.clamp(fp8_info.min, fp8_info.max).to(fp8_dtype)

    # Reinterpret both tensors as raw bytes and pack them side by side.
    k_fp8_bytes = k_fp8.view(-1, head_dim).view(torch.uint8)
    scale_bytes = scales.contiguous().view(torch.uint8).view(-1, 4)
    packed = torch.cat(
        [k_fp8_bytes, scale_bytes], dim=-1
    )  # [total_tokens, head_dim + 4]

    slot_mapping = slot_mapping.flatten()
    # kv_cache: [num_block, block_size, head_dim + 4]
    kv_cache.view(-1, kv_cache.shape[-1]).index_copy_(0, slot_mapping, packed)
@staticmethod
def cp_gather_indexer_k_quant_cache(
kv_cache: torch.Tensor,
dst_k: torch.Tensor,
dst_scale: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
) -> None:
"""
Args:
kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache
Layout per block: [k_values, scale_values]
- k_values: [block_size * head_dim]
- scale_values: [block_size * head_dim * 4 / quant_block_size]
dst_k: [num_tokens, head_dim] - output tensor for K values
dst_scale: [num_tokens, head_dim / quant_block_size * 4]
- output tensor for scale values
block_table: [batch_size, num_blocks] - block table for indexing
cu_seq_lens: [batch_size + 1] - cumulative sequence lengths
"""
batch_size = block_table.size(0)
num_tokens = dst_k.size(0)
head_dim = dst_k.size(1)
cache_block_size = kv_cache.size(1)
quant_block_size = head_dim * 4 // dst_scale.size(1)
# For each token, find which batch it belongs to using searchsorted
token_indices = torch.arange(num_tokens, device=dst_k.device) + 1
# cu_seq_lens is [batch_size + 1], we need to find which interval each
# token belongs to
batch_indices = torch.searchsorted(cu_seq_lens, token_indices) - 1
batch_indices = torch.clamp(batch_indices, 0, batch_size - 1)
# Calculate the in-batch sequence index for each token
inbatch_seq_indices = token_indices - cu_seq_lens[batch_indices]
# Find which block each token belongs to
block_indices_in_table = inbatch_seq_indices // cache_block_size
physical_block_indices = block_table[batch_indices, block_indices_in_table]
# Calculate the offset within each block
inblock_offsets = (inbatch_seq_indices - 1) % cache_block_size
# Calculate strides
block_stride = kv_cache.stride(0) # stride for each block
# Flatten kv_cache for easier indexing
kv_cache_flat = kv_cache.view(-1)
# Calculate source offset for K values for all tokens (vectorized)
src_block_offsets = physical_block_indices * block_stride
src_k_offsets = src_block_offsets + inblock_offsets * head_dim
# Gather K values using advanced indexing
# Create indices for all elements we need to gather
k_indices = src_k_offsets.unsqueeze(1) + torch.arange(
head_dim, device=dst_k.device
)
dst_k[:] = kv_cache_flat[k_indices]
# Calculate source offset for scale values (vectorized)
# Scales are stored after all K values for each block
scale_size = head_dim * 4 // quant_block_size
src_scale_offsets = src_block_offsets + head_dim + inblock_offsets * scale_size
# Gather scale values
scale_indices = src_scale_offsets.unsqueeze(1) + torch.arange(
scale_size, device=dst_scale.device
)
dst_scale[:] = kv_cache_flat[scale_indices]
@staticmethod
def top_k_per_row_prefill(
logits: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
raw_topk_indices: torch.Tensor,
num_rows: int,
stride0: int,
strdide1: int,
topk_tokens: int,
) -> torch.Tensor:
real_topk = min(topk_tokens, logits.shape[-1])
topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32)
topk_indices -= cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
mask = torch.full_like(
topk_indices, False, dtype=torch.bool, device=topk_indices.device
)
mask = mask_lo & mask_hi
topk_indices.masked_fill_(~mask, -1)
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices
)
@staticmethod
def top_k_per_row_decode(
logits: torch.Tensor,
next_n: int,
seq_lens: torch.Tensor,
raw_topk_indices: torch.Tensor,
num_rows: int,
stride0: int,
stride1: int,
topk_tokens: int,
) -> torch.Tensor:
device = logits.device
batch_size = seq_lens.size(0)
# padded query len
padded_num_tokens = batch_size * next_n
positions = (
torch.arange(logits.shape[-1], device=device)
.unsqueeze(0)
.expand(batch_size * next_n, -1)
)
row_indices = torch.arange(padded_num_tokens, device=device) // next_n
next_n_offset = torch.arange(padded_num_tokens, device=device) % next_n
index_end_pos = (seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices[topk_indices > index_end_pos] = -1
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices
)
@staticmethod
def register_ops_once() -> None:
    """Register the custom ops of this module the first time it is called.

    The module-level ``_OPS_REGISTERED`` flag is set after registration so
    that later calls return without registering again.
    """
    global _OPS_REGISTERED
    if _OPS_REGISTERED:
        return

    # register all the custom ops here
    direct_register_custom_op(
        op_name="xpu_ops_deepseek_scaling_rope",
        op_func=_xpu_ops_deepseek_scaling_rope_impl,
        mutates_args=[],
        fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
        dispatch_key=current_platform.dispatch_key,
    )
    _OPS_REGISTERED = True
xpu_ops.register_ops_once()
\ No newline at end of file
...@@ -337,9 +337,10 @@ class DynamicShapesConfig: ...@@ -337,9 +337,10 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239. until this change picked up https://github.com/pytorch/pytorch/pull/169239.
""" """
assume_32_bit_indexing: bool = True assume_32_bit_indexing: bool = False
""" """
whether all tensor sizes can use 32 bit indexing. whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
...@@ -259,6 +259,16 @@ class SpeculativeConfig: ...@@ -259,6 +259,16 @@ class SpeculativeConfig:
} }
) )
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe": if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp" hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp": if hf_config.model_type == "ernie_mtp":
......
...@@ -72,7 +72,7 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -72,7 +72,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer return buffer
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -96,6 +96,34 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -96,6 +96,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
) )
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
...@@ -127,7 +155,7 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -127,7 +155,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group, tcp_store_group=None): def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group) super().__init__(cpu_group, tcp_store_group)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -161,6 +189,46 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -161,6 +189,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1] return gathered_tensors[0], gathered_tensors[1]
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Gather hidden_states, topk_weights and topk_ids (plus any extra tensors)
    along dim 0 with ``all_gatherv``.

    The EP group is used when ``is_sequence_parallel`` is set, otherwise the
    DP group. When ``extra_tensors`` is given, the gathered extras are
    returned as a fourth element.
    """
    metadata = get_forward_context().dp_metadata
    assert metadata is not None
    sizes = metadata.get_chunk_sizes_across_dp_rank()
    assert sizes is not None

    if is_sequence_parallel:
        dist_group = get_ep_group()
    else:
        dist_group = get_dp_group()
    # This rank's chunk size must match the number of local rows.
    assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]

    extras = [] if extra_tensors is None else extra_tensors
    gathered = dist_group.all_gatherv(
        [hidden_states, topk_weights, topk_ids, *extras],
        dim=0,
        sizes=sizes,
    )

    core = (gathered[0], gathered[1], gathered[2])
    if extra_tensors is None:
        return core
    return (*core, gathered[3:])
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -200,7 +268,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -200,7 +268,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -209,6 +277,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -209,6 +277,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """Not implemented on this class; always raises NotImplementedError."""
    raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import torch import torch
...@@ -70,13 +69,32 @@ class All2AllManagerBase: ...@@ -70,13 +69,32 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> Any: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either: # Subclasses should either:
# - implement handling for extra_tensors, or # - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported. # - raise a clear error if extra_tensors is not supported.
...@@ -312,7 +330,7 @@ class DeviceCommunicatorBase: ...@@ -312,7 +330,7 @@ class DeviceCommunicatorBase:
for module in moe_modules: for module in moe_modules:
module.maybe_init_modular_kernel() module.maybe_init_modular_kernel()
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -326,8 +344,29 @@ class DeviceCommunicatorBase: ...@@ -326,8 +344,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -338,4 +377,4 @@ class DeviceCommunicatorBase: ...@@ -338,4 +377,4 @@ class DeviceCommunicatorBase:
return hidden_states return hidden_states
def batch_isend_irecv(self, p2p_ops: list): def batch_isend_irecv(self, p2p_ops: list):
raise NotImplementedError raise NotImplementedError
\ No newline at end of file
...@@ -151,29 +151,65 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -151,29 +151,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override] def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
class _CPUSHMDistributed: class _CPUSHMDistributed:
......
...@@ -396,7 +396,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -396,7 +396,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( # type: ignore[override] def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -406,20 +406,54 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -406,20 +406,54 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor] tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
def batch_isend_irecv(self, p2p_ops: list): def batch_isend_irecv(self, p2p_ops: list):
...@@ -427,4 +461,4 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -427,4 +461,4 @@ class CudaCommunicator(DeviceCommunicatorBase):
if pynccl_comm is not None and not pynccl_comm.disabled: if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.batch_isend_irecv(p2p_ops) pynccl_comm.batch_isend_irecv(p2p_ops)
else: else:
raise ValueError("No PyNCCL communicator found") raise ValueError("No PyNCCL communicator found")
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend from flashinfer.comm.mnnvl import CommBackend as CommBackend
......
...@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
) is_sequence_parallel,
return hidden_states )
\ No newline at end of file
...@@ -384,7 +384,9 @@ class TpKVTopology: ...@@ -384,7 +384,9 @@ class TpKVTopology:
@property @property
def split_k_and_v(self) -> bool: def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present). # Whether to register regions for K and V separately (when present).
return not (self.is_mla or self.is_kv_layout_blocks_first) return not (
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
@property @property
def tp_size(self) -> int: def tp_size(self) -> int:
...@@ -554,4 +556,4 @@ def get_current_attn_backend( ...@@ -554,4 +556,4 @@ def get_current_attn_backend(
vllm_config: VllmConfig, layer_names: list[str] | None = None vllm_config: VllmConfig, layer_names: list[str] | None = None
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Get the first attention backend for the given layers.""" """Get the first attention backend for the given layers."""
return get_current_attn_backends(vllm_config, layer_names)[0] return get_current_attn_backends(vllm_config, layer_names)[0]
\ No newline at end of file
...@@ -56,7 +56,7 @@ from vllm.logger import init_logger ...@@ -56,7 +56,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
...@@ -186,7 +186,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata): ...@@ -186,7 +186,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash( def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
) -> str: ) -> str:
""" """
Compute compatibility hash for NIXL KV transfer. Compute compatibility hash for NIXL KV transfer.
...@@ -1164,12 +1164,9 @@ class NixlConnectorWorker: ...@@ -1164,12 +1164,9 @@ class NixlConnectorWorker:
logger.info("Detected attention backend %s", self.backend_name) logger.info("Detected attention backend %s", self.backend_name)
logger.info("Detected kv cache layout %s", self.kv_cache_layout) logger.info("Detected kv cache layout %s", self.kv_cache_layout)
self.compat_hash = compute_nixl_compatibility_hash( # lazy initialized in register_kv_caches
self.vllm_config, self.backend_name self.compat_hash: str | None = None
) self.kv_topo: TpKVTopology | None = None
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
...@@ -1184,7 +1181,6 @@ class NixlConnectorWorker: ...@@ -1184,7 +1181,6 @@ class NixlConnectorWorker:
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True "enforce_handshake_compat", True
) )
self._physical_blocks_per_logical_kv_block = 1
def _sync_block_size_with_kernel(self) -> None: def _sync_block_size_with_kernel(self) -> None:
backends = get_current_attn_backends(self.vllm_config) backends = get_current_attn_backends(self.vllm_config)
...@@ -1232,6 +1228,7 @@ class NixlConnectorWorker: ...@@ -1232,6 +1228,7 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current # Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP, # local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i. # this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {} remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port) path = make_zmq_path("tcp", host, port)
...@@ -1269,6 +1266,7 @@ class NixlConnectorWorker: ...@@ -1269,6 +1266,7 @@ class NixlConnectorWorker:
) )
# Check compatibility hash BEFORE decoding agent metadata # Check compatibility hash BEFORE decoding agent metadata
assert self.compat_hash is not None
if ( if (
self.enforce_compat_hash self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash and handshake_payload.compatibility_hash != self.compat_hash
...@@ -1547,7 +1545,6 @@ class NixlConnectorWorker: ...@@ -1547,7 +1545,6 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB). # (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region # Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim). # to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None tensor_size_bytes = None
# Enable different block lengths for different layers *only* when MLA is used. # Enable different block lengths for different layers *only* when MLA is used.
...@@ -1698,6 +1695,7 @@ class NixlConnectorWorker: ...@@ -1698,6 +1695,7 @@ class NixlConnectorWorker:
ssm_sizes=self._mamba_ssm_size, ssm_sizes=self._mamba_ssm_size,
) )
# Wrap metadata in payload with hash for defensive decoding # Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
encoder = msgspec.msgpack.Encoder() encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload( self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash, compatibility_hash=self.compat_hash,
...@@ -2177,6 +2175,7 @@ class NixlConnectorWorker: ...@@ -2177,6 +2175,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0: if len(self.device_kv_caches) == 0:
return return
assert block_size_ratio >= 1, "Only nP < nD supported currently." assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1: if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug( logger.debug(
"Post-processing device kv cache on receive by converting " "Post-processing device kv cache on receive by converting "
...@@ -2196,7 +2195,7 @@ class NixlConnectorWorker: ...@@ -2196,7 +2195,7 @@ class NixlConnectorWorker:
block_size_ratio, block_size_ratio,
) )
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first) split_k_and_v = self.kv_topo.split_k_and_v
for block_ids in block_ids_list: for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
...@@ -2221,6 +2220,7 @@ class NixlConnectorWorker: ...@@ -2221,6 +2220,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done. to track which workers are done.
""" """
assert self.kv_topo is not None
done_sending = self._get_new_notifs() done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers) done_recving = self._pop_done_transfers(self._recving_transfers)
...@@ -2291,6 +2291,7 @@ class NixlConnectorWorker: ...@@ -2291,6 +2291,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling. for all consumers to be done pulling.
""" """
assert self.kv_topo is not None
notified_req_ids: set[str] = set() notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
...@@ -2451,7 +2452,7 @@ class NixlConnectorWorker: ...@@ -2451,7 +2452,7 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None assert meta.remote is not None and self.kv_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id meta.remote.engine_id
) )
...@@ -2782,6 +2783,7 @@ class NixlConnectorWorker: ...@@ -2782,6 +2783,7 @@ class NixlConnectorWorker:
+-------------------+ +--------------------+ +-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split | |1st_split-2nd_split| |1st_split-2nd_split |
""" """
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first: if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part). # For indexing only half (either just the K or V part).
if mamba_view: if mamba_view:
...@@ -3103,4 +3105,4 @@ class NixlPromMetrics(KVConnectorPromMetrics): ...@@ -3103,4 +3105,4 @@ class NixlPromMetrics(KVConnectorPromMetrics):
["num_failed_transfers", "num_failed_notifications", "num_kv_expired_reqs"], ["num_failed_transfers", "num_failed_notifications", "num_kv_expired_reqs"],
): ):
for list_item in transfer_stats_data[counter_item_key]: for list_item in transfer_stats_data[counter_item_key]:
counter_obj[engine_idx].inc(list_item) counter_obj[engine_idx].inc(list_item)
\ No newline at end of file
...@@ -1065,7 +1065,7 @@ class GroupCoordinator: ...@@ -1065,7 +1065,7 @@ class GroupCoordinator:
if self.device_communicator is not None: if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model) self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1076,7 +1076,7 @@ class GroupCoordinator: ...@@ -1076,7 +1076,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg] return self.device_communicator.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
...@@ -1085,6 +1085,28 @@ class GroupCoordinator: ...@@ -1085,6 +1085,28 @@ class GroupCoordinator:
else: else:
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states, is_sequence_parallel: bool = False self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -2090,4 +2112,4 @@ def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int: ...@@ -2090,4 +2112,4 @@ def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
if is_same_node and node_assignment[other_rank] == 0: if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id node_assignment[other_rank] = next_node_id
return next_node_id return next_node_id
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment