Commit eefa41c1 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.18.0

parent 82155c76
......@@ -266,7 +266,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"),
"ExaoneMoEForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.0.0"
"LGAI-EXAONE/K-EXAONE-236B-A23B", min_transformers_version="5.1.0"
),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
......@@ -283,11 +283,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"),
"Glm4MoeLiteForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev",
is_available_online=False,
),
"GlmMoeDsaForCausalLM": _HfExamplesInfo(
"zai-org/GLM-5", min_transformers_version="5.0.1", is_available_online=False
min_transformers_version="5.0.0",
),
"GlmMoeDsaForCausalLM": _HfExamplesInfo(
"zai-org/GLM-5", min_transformers_version="5.0.1", is_available_online=False
......@@ -743,7 +739,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
),
"MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev"
......@@ -1237,7 +1233,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
speculative_model="zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0",
),
"GlmOcrMTPModel": _HfExamplesInfo(
"zai-org/GLM-OCR",
speculative_model="zai-org/GLM-OCR",
is_available_online=False,
min_transformers_version="5.1.0",
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
......@@ -1282,27 +1284,27 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": _HfExamplesInfo(
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0.dev"
"BAAI/bge-base-en-v1.5", min_transformers_version="5.0.0"
),
"TransformersForSequenceClassification": _HfExamplesInfo(
"papluca/xlm-roberta-base-language-detection",
min_transformers_version="5.0.0.dev",
min_transformers_version="5.0.0",
),
"TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True
),
"TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0.dev"
"allenai/OLMoE-1B-7B-0924", min_transformers_version="5.0.0"
),
"TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0.dev"
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="5.0.0"
),
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0.dev"
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
),
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0.dev"
"Qwen/Qwen3-30B-A3B", min_transformers_version="5.0.0"
),
"TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"),
"TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
......
......@@ -76,7 +76,7 @@ def test_models(
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
required = Version("5.0.0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip(
"MoE models with the Transformers modeling backend require "
......
......@@ -36,7 +36,7 @@ class MyGemma2Embedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the UvicornAccessLogFilter class.
"""
import logging
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
class TestUvicornAccessLogFilter:
"""Test cases for UvicornAccessLogFilter."""
def test_filter_allows_all_when_no_excluded_paths(self):
"""Filter should allow all logs when no paths are excluded."""
filter = UvicornAccessLogFilter(excluded_paths=[])
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/v1/completions", "1.1", 200),
exc_info=None,
)
assert filter.filter(record) is True
def test_filter_allows_all_when_excluded_paths_is_none(self):
"""Filter should allow all logs when excluded_paths is None."""
filter = UvicornAccessLogFilter(excluded_paths=None)
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/health", "1.1", 200),
exc_info=None,
)
assert filter.filter(record) is True
def test_filter_excludes_health_endpoint(self):
"""Filter should exclude /health endpoint when configured."""
filter = UvicornAccessLogFilter(excluded_paths=["/health"])
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/health", "1.1", 200),
exc_info=None,
)
assert filter.filter(record) is False
def test_filter_excludes_metrics_endpoint(self):
"""Filter should exclude /metrics endpoint when configured."""
filter = UvicornAccessLogFilter(excluded_paths=["/metrics"])
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/metrics", "1.1", 200),
exc_info=None,
)
assert filter.filter(record) is False
def test_filter_allows_non_excluded_endpoints(self):
"""Filter should allow endpoints not in the excluded list."""
filter = UvicornAccessLogFilter(excluded_paths=["/health", "/metrics"])
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "POST", "/v1/completions", "1.1", 200),
exc_info=None,
)
assert filter.filter(record) is True
def test_filter_excludes_multiple_endpoints(self):
"""Filter should exclude multiple configured endpoints."""
filter = UvicornAccessLogFilter(excluded_paths=["/health", "/metrics", "/ping"])
# Test /health
record_health = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/health", "1.1", 200),
exc_info=None,
)
assert filter.filter(record_health) is False
# Test /metrics
record_metrics = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/metrics", "1.1", 200),
exc_info=None,
)
assert filter.filter(record_metrics) is False
# Test /ping
record_ping = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/ping", "1.1", 200),
exc_info=None,
)
assert filter.filter(record_ping) is False
def test_filter_with_query_parameters(self):
"""Filter should exclude endpoints even with query parameters."""
filter = UvicornAccessLogFilter(excluded_paths=["/health"])
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/health?verbose=true", "1.1", 200),
exc_info=None,
)
assert filter.filter(record) is False
def test_filter_different_http_methods(self):
"""Filter should exclude endpoints regardless of HTTP method."""
filter = UvicornAccessLogFilter(excluded_paths=["/ping"])
# Test GET
record_get = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/ping", "1.1", 200),
exc_info=None,
)
assert filter.filter(record_get) is False
# Test POST
record_post = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "POST", "/ping", "1.1", 200),
exc_info=None,
)
assert filter.filter(record_post) is False
def test_filter_with_different_status_codes(self):
"""Filter should exclude endpoints regardless of status code."""
filter = UvicornAccessLogFilter(excluded_paths=["/health"])
for status_code in [200, 500, 503]:
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg='%s - "%s %s HTTP/%s" %d',
args=("127.0.0.1:12345", "GET", "/health", "1.1", status_code),
exc_info=None,
)
assert filter.filter(record) is False
class TestCreateUvicornLogConfig:
"""Test cases for create_uvicorn_log_config function."""
def test_creates_valid_config_structure(self):
"""Config should have required logging configuration keys."""
config = create_uvicorn_log_config(excluded_paths=["/health"])
assert "version" in config
assert config["version"] == 1
assert "disable_existing_loggers" in config
assert "formatters" in config
assert "handlers" in config
assert "loggers" in config
assert "filters" in config
def test_config_includes_access_log_filter(self):
"""Config should include the access log filter."""
config = create_uvicorn_log_config(excluded_paths=["/health", "/metrics"])
assert "access_log_filter" in config["filters"]
filter_config = config["filters"]["access_log_filter"]
assert filter_config["()"] == UvicornAccessLogFilter
assert filter_config["excluded_paths"] == ["/health", "/metrics"]
def test_config_applies_filter_to_access_handler(self):
"""Config should apply the filter to the access handler."""
config = create_uvicorn_log_config(excluded_paths=["/health"])
assert "access" in config["handlers"]
assert "filters" in config["handlers"]["access"]
assert "access_log_filter" in config["handlers"]["access"]["filters"]
def test_config_with_custom_log_level(self):
"""Config should respect custom log level."""
config = create_uvicorn_log_config(
excluded_paths=["/health"], log_level="debug"
)
assert config["loggers"]["uvicorn"]["level"] == "DEBUG"
assert config["loggers"]["uvicorn.access"]["level"] == "DEBUG"
assert config["loggers"]["uvicorn.error"]["level"] == "DEBUG"
def test_config_with_empty_excluded_paths(self):
"""Config should work with empty excluded paths."""
config = create_uvicorn_log_config(excluded_paths=[])
assert config["filters"]["access_log_filter"]["excluded_paths"] == []
def test_config_with_none_excluded_paths(self):
"""Config should work with None excluded paths."""
config = create_uvicorn_log_config(excluded_paths=None)
assert config["filters"]["access_log_filter"]["excluded_paths"] == []
class TestIntegration:
"""Integration tests for the access log filter."""
def test_filter_with_real_logger(self):
"""Test filter works with a real Python logger simulating uvicorn."""
# Create a logger with our filter (simulating uvicorn.access)
logger = logging.getLogger("uvicorn.access")
logger.setLevel(logging.INFO)
# Clear any existing handlers
logger.handlers = []
# Create a custom handler that tracks messages
logged_messages: list[str] = []
class TrackingHandler(logging.Handler):
def emit(self, record):
logged_messages.append(record.getMessage())
handler = TrackingHandler()
handler.setLevel(logging.INFO)
filter = UvicornAccessLogFilter(excluded_paths=["/health", "/metrics"])
handler.addFilter(filter)
logger.addHandler(handler)
# Log using uvicorn's format with args tuple
# Format: '%s - "%s %s HTTP/%s" %d'
logger.info(
'%s - "%s %s HTTP/%s" %d',
"127.0.0.1:12345",
"GET",
"/health",
"1.1",
200,
)
logger.info(
'%s - "%s %s HTTP/%s" %d',
"127.0.0.1:12345",
"GET",
"/v1/completions",
"1.1",
200,
)
logger.info(
'%s - "%s %s HTTP/%s" %d',
"127.0.0.1:12345",
"GET",
"/metrics",
"1.1",
200,
)
logger.info(
'%s - "%s %s HTTP/%s" %d',
"127.0.0.1:12345",
"POST",
"/v1/chat/completions",
"1.1",
200,
)
# Verify only non-excluded endpoints were logged
assert len(logged_messages) == 2
assert "/v1/completions" in logged_messages[0]
assert "/v1/chat/completions" in logged_messages[1]
def test_filter_allows_non_uvicorn_access_logs(self):
"""Test filter allows logs from non-uvicorn.access loggers."""
filter = UvicornAccessLogFilter(excluded_paths=["/health"])
# Log record from a different logger name
record = logging.LogRecord(
name="uvicorn.error",
level=logging.INFO,
pathname="",
lineno=0,
msg="Some error message about /health",
args=(),
exc_info=None,
)
# Should allow because it's not from uvicorn.access
assert filter.filter(record) is True
def test_filter_handles_malformed_args(self):
"""Test filter handles log records with unexpected args format."""
filter = UvicornAccessLogFilter(excluded_paths=["/health"])
# Log record with insufficient args
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg="Some message",
args=("only", "two"),
exc_info=None,
)
# Should allow because args doesn't have expected format
assert filter.filter(record) is True
def test_filter_handles_non_tuple_args(self):
"""Test filter handles log records with non-tuple args."""
filter = UvicornAccessLogFilter(excluded_paths=["/health"])
# Log record with None args
record = logging.LogRecord(
name="uvicorn.access",
level=logging.INFO,
pathname="",
lineno=0,
msg="Some message without args",
args=None,
exc_info=None,
)
# Should allow because args is None
assert filter.filter(record) is True
\ No newline at end of file
......@@ -383,7 +383,7 @@ def _run_eagle_correctness(
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
required = Version("5.0.0")
if installed < required:
pytest.skip(
"Eagle3 with the Transformers modeling backend requires "
......
......@@ -59,9 +59,9 @@ fi
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
fi
# Models to run
......
......@@ -18,8 +18,12 @@ import ray
import torch
from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
......@@ -58,6 +62,8 @@ from vllm.v1.kv_cache_interface import (
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import (
create_request,
......@@ -1498,44 +1504,6 @@ def test_register_kv_caches(
backend_cls = TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
......@@ -1716,14 +1684,13 @@ def test_register_kv_caches(
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
else:
num_blocks = kv_cache_config.num_blocks
if is_blocks_first:
......@@ -2360,7 +2327,9 @@ def test_compatibility_hash_validation(
)
)
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, decode_worker.backend_name
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)
prefill_block_size = config_overrides.get("block_size", 16)
......
......@@ -3044,13 +3044,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class CPUDNNLGEMMHandler:
def __init__(self) -> None:
self.handler: int | None = None
self.handler_tensor: torch.Tensor | None = None
self.n = -1
self.k = -1
def __del__(self):
if self.handler is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
if self.handler_tensor is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item())
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
......@@ -3066,8 +3066,10 @@ def create_onednn_mm(
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler(
weight, primitive_cache_size
# store the handler pointer in a tensor it doesn't get inlined
handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size),
dtype=torch.int64,
)
return handler
......@@ -3079,7 +3081,7 @@ def onednn_mm(
) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm(
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor
)
return output
......@@ -3095,8 +3097,17 @@ def create_onednn_scaled_mm(
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size
# store the handler pointer in a tensor so it doesn't get inlined
handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_scaled_mm_handler(
weight,
weight_scales,
output_type,
dynamic_quant,
use_azp,
primitive_cache_size,
),
dtype=torch.int64,
)
return handler
......@@ -3149,11 +3160,15 @@ def onednn_scaled_mm(
bias: torch.Tensor | None,
) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm(
output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler
output,
x,
input_scale,
input_zp,
input_zp_adj,
bias,
dnnl_handler.handler_tensor,
)
return output
def cpu_attn_get_scheduler_metadata(
num_reqs: int,
......
......@@ -7,6 +7,8 @@ import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -53,6 +55,37 @@ if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):
return torch.empty((M, N), dtype=input.dtype, device=input.device)
def _xpu_ops_deepseek_scaling_rope_impl(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
assert key is not None
return torch.ops._xpu_C.deepseek_scaling_rope(
positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style
)
def _xpu_ops_deepseek_scaling_rope_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return query, key
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
class xpu_ops:
@staticmethod
def flash_attn_varlen_func(
......@@ -105,9 +138,10 @@ class xpu_ops:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1]) # noqa: F841
# In encode attention, v maybe not contiguous and current
# In encode attention, k and v maybe not contiguous and current
# kernel can't handle it
if block_table is None:
k = k.contiguous()
v = v.contiguous()
return flash_attn_varlen_func(
out=out,
......@@ -156,3 +190,265 @@ class xpu_ops:
"get_scheduler_metadata is not implemented for xpu_ops, returning None."
)
return None
@staticmethod
def indexer_k_quant_and_cache(
k: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
) -> None:
head_dim = k.shape[-1]
k = k.view(-1, head_dim) # [total_tokens, head_dim]
def group_quant_torch(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype | None = None,
column_major_scales: bool = False,
out_q: torch.Tensor | None = None,
use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if use_ue8m0 is None:
# Default fallback - could import is_deep_gemm_e8m0_used if needed
use_ue8m0 = False
if dtype is None:
dtype = current_platform.fp8_dtype()
# Validate inputs
assert x.shape[-1] % group_size == 0, (
f"Last dimension {x.shape[-1]} must be divisible by "
f"group_size {group_size}"
)
assert x.stride(-1) == 1, "Input tensor groups must be contiguous"
# Prepare output tensor
if out_q is None:
x_q = torch.empty_like(x, dtype=dtype)
else:
assert out_q.shape == x.shape
x_q = out_q
# Reshape input for group processing
# Original shape: (..., last_dim)
# Target shape: (..., num_groups, group_size)
original_shape = x.shape
num_groups = original_shape[-1] // group_size
# Reshape to separate groups
group_shape = original_shape[:-1] + (num_groups, group_size)
x_grouped = x.view(group_shape)
# Compute per-group absolute maximum values
# Shape: (..., num_groups)
abs_max = torch.amax(torch.abs(x_grouped), dim=-1, keepdim=False)
abs_max = torch.maximum(
abs_max, torch.tensor(eps, device=x.device, dtype=x.dtype)
)
# Compute scales
FP8_MAX = torch.finfo(dtype).max
FP8_MIN = torch.finfo(dtype).min
scale_raw = abs_max / FP8_MAX
if use_ue8m0:
# For UE8M0 format, scales must be powers of 2
scales = torch.pow(2.0, torch.ceil(torch.log2(scale_raw)))
else:
scales = scale_raw
# Expand scales for broadcasting with grouped data
# Shape: (..., num_groups, 1)
scales_expanded = scales.unsqueeze(-1)
# Quantize the grouped data
x_scaled = x_grouped / scales_expanded
x_clamped = torch.clamp(x_scaled, FP8_MIN, FP8_MAX)
x_quantized = x_clamped.to(dtype)
# Reshape back to original shape
x_q.copy_(x_quantized.view(original_shape))
# Prepare scales tensor in requested format
if column_major_scales:
# Column-major: (num_groups,) + batch_dims
# Transpose the scales to put group dimension first
scales_shape = (num_groups,) + original_shape[:-1]
x_s = scales.permute(-1, *range(len(original_shape) - 1))
x_s = x_s.contiguous().view(scales_shape)
else:
# Row-major: batch_dims + (num_groups,)
x_s = scales.contiguous()
# Ensure scales are float32
return x_q, x_s.float()
k_fp8, k_scale = group_quant_torch(
k,
group_size=quant_block_size,
column_major_scales=False,
use_ue8m0=(scale_fmt == "ue8m0"),
)
k_fp8_bytes = k_fp8.view(-1, head_dim).view(torch.uint8)
scale_bytes = k_scale.view(torch.uint8).view(-1, 4)
k = torch.cat(
[k_fp8_bytes, scale_bytes], dim=-1
) # [total_tokens, head_dim + 4]
slot_mapping = slot_mapping.flatten()
# kv_cache: [num_block, block_size, head_dim + 4]
kv_cache.view(-1, kv_cache.shape[-1]).index_copy_(0, slot_mapping, k)
@staticmethod
def cp_gather_indexer_k_quant_cache(
kv_cache: torch.Tensor,
dst_k: torch.Tensor,
dst_scale: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
) -> None:
"""
Args:
kv_cache: [num_blocks, block_size, cache_stride] - quantized KV cache
Layout per block: [k_values, scale_values]
- k_values: [block_size * head_dim]
- scale_values: [block_size * head_dim * 4 / quant_block_size]
dst_k: [num_tokens, head_dim] - output tensor for K values
dst_scale: [num_tokens, head_dim / quant_block_size * 4]
- output tensor for scale values
block_table: [batch_size, num_blocks] - block table for indexing
cu_seq_lens: [batch_size + 1] - cumulative sequence lengths
"""
batch_size = block_table.size(0)
num_tokens = dst_k.size(0)
head_dim = dst_k.size(1)
cache_block_size = kv_cache.size(1)
quant_block_size = head_dim * 4 // dst_scale.size(1)
# For each token, find which batch it belongs to using searchsorted
token_indices = torch.arange(num_tokens, device=dst_k.device) + 1
# cu_seq_lens is [batch_size + 1], we need to find which interval each
# token belongs to
batch_indices = torch.searchsorted(cu_seq_lens, token_indices) - 1
batch_indices = torch.clamp(batch_indices, 0, batch_size - 1)
# Calculate the in-batch sequence index for each token
inbatch_seq_indices = token_indices - cu_seq_lens[batch_indices]
# Find which block each token belongs to
block_indices_in_table = inbatch_seq_indices // cache_block_size
physical_block_indices = block_table[batch_indices, block_indices_in_table]
# Calculate the offset within each block
inblock_offsets = (inbatch_seq_indices - 1) % cache_block_size
# Calculate strides
block_stride = kv_cache.stride(0) # stride for each block
# Flatten kv_cache for easier indexing
kv_cache_flat = kv_cache.view(-1)
# Calculate source offset for K values for all tokens (vectorized)
src_block_offsets = physical_block_indices * block_stride
src_k_offsets = src_block_offsets + inblock_offsets * head_dim
# Gather K values using advanced indexing
# Create indices for all elements we need to gather
k_indices = src_k_offsets.unsqueeze(1) + torch.arange(
head_dim, device=dst_k.device
)
dst_k[:] = kv_cache_flat[k_indices]
# Calculate source offset for scale values (vectorized)
# Scales are stored after all K values for each block
scale_size = head_dim * 4 // quant_block_size
src_scale_offsets = src_block_offsets + head_dim + inblock_offsets * scale_size
# Gather scale values
scale_indices = src_scale_offsets.unsqueeze(1) + torch.arange(
scale_size, device=dst_scale.device
)
dst_scale[:] = kv_cache_flat[scale_indices]
@staticmethod
def top_k_per_row_prefill(
logits: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
raw_topk_indices: torch.Tensor,
num_rows: int,
stride0: int,
strdide1: int,
topk_tokens: int,
) -> torch.Tensor:
real_topk = min(topk_tokens, logits.shape[-1])
topk_indices = logits.topk(real_topk, dim=-1)[1].to(torch.int32)
topk_indices -= cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
mask = torch.full_like(
topk_indices, False, dtype=torch.bool, device=topk_indices.device
)
mask = mask_lo & mask_hi
topk_indices.masked_fill_(~mask, -1)
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices
)
@staticmethod
def top_k_per_row_decode(
logits: torch.Tensor,
next_n: int,
seq_lens: torch.Tensor,
raw_topk_indices: torch.Tensor,
num_rows: int,
stride0: int,
stride1: int,
topk_tokens: int,
) -> torch.Tensor:
device = logits.device
batch_size = seq_lens.size(0)
# padded query len
padded_num_tokens = batch_size * next_n
positions = (
torch.arange(logits.shape[-1], device=device)
.unsqueeze(0)
.expand(batch_size * next_n, -1)
)
row_indices = torch.arange(padded_num_tokens, device=device) // next_n
next_n_offset = torch.arange(padded_num_tokens, device=device) % next_n
index_end_pos = (seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices[topk_indices > index_end_pos] = -1
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices
)
@staticmethod
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
# register all the custom ops here
direct_register_custom_op(
op_name="xpu_ops_deepseek_scaling_rope",
op_func=_xpu_ops_deepseek_scaling_rope_impl,
mutates_args=[],
fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
xpu_ops.register_ops_once()
\ No newline at end of file
......@@ -337,9 +337,10 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
assume_32_bit_indexing: bool = True
assume_32_bit_indexing: bool = False
"""
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
def compute_hash(self) -> str:
......
......@@ -259,6 +259,16 @@ class SpeculativeConfig:
}
)
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
......
......@@ -72,7 +72,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -97,6 +97,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -127,7 +155,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -161,6 +189,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]
if extra_tensors is None:
return hidden_states, topk_weights, topk_ids
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -200,7 +268,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -209,6 +277,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
......@@ -70,13 +69,32 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
......@@ -312,7 +330,7 @@ class DeviceCommunicatorBase:
for module in moe_modules:
module.maybe_init_modular_kernel()
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -326,8 +344,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -151,29 +151,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override]
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
class _CPUSHMDistributed:
......
......@@ -396,7 +396,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch( # type: ignore[override]
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -406,20 +406,54 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
def batch_isend_irecv(self, p2p_ops: list):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
......
......@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
\ No newline at end of file
return hidden_states
......@@ -384,7 +384,9 @@ class TpKVTopology:
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (self.is_mla or self.is_kv_layout_blocks_first)
return not (
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
......
......@@ -56,7 +56,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (
......@@ -186,7 +186,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str
vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
......@@ -1164,12 +1164,9 @@ class NixlConnectorWorker:
logger.info("Detected attention backend %s", self.backend_name)
logger.info("Detected kv cache layout %s", self.kv_cache_layout)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name
)
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
# lazy initialized in register_kv_caches
self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
......@@ -1184,7 +1181,6 @@ class NixlConnectorWorker:
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._physical_blocks_per_logical_kv_block = 1
def _sync_block_size_with_kernel(self) -> None:
backends = get_current_attn_backends(self.vllm_config)
......@@ -1232,6 +1228,7 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
......@@ -1269,6 +1266,7 @@ class NixlConnectorWorker:
)
# Check compatibility hash BEFORE decoding agent metadata
assert self.compat_hash is not None
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
......@@ -1547,7 +1545,6 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
# Enable different block lengths for different layers *only* when MLA is used.
......@@ -1698,6 +1695,7 @@ class NixlConnectorWorker:
ssm_sizes=self._mamba_ssm_size,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash,
......@@ -2177,6 +2175,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
......@@ -2196,7 +2195,7 @@ class NixlConnectorWorker:
block_size_ratio,
)
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
split_k_and_v = self.kv_topo.split_k_and_v
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
......@@ -2221,6 +2220,7 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
assert self.kv_topo is not None
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
......@@ -2291,6 +2291,7 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert self.kv_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
......@@ -2451,7 +2452,7 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
assert meta.remote is not None and self.kv_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
......@@ -2782,6 +2783,7 @@ class NixlConnectorWorker:
+-------------------+ +--------------------+
|1st_split-2nd_split| |1st_split-2nd_split |
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
if mamba_view:
......
......@@ -1065,7 +1065,7 @@ class GroupCoordinator:
if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -1076,7 +1076,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg]
return self.device_communicator.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
......@@ -1085,6 +1085,28 @@ class GroupCoordinator:
else:
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment