Unverified commit f44afef6, authored by Richard Zou, committed by GitHub
Browse files

[compile] Allow strings in custom ops without regressing compilation times (#38123)


Signed-off-by: Richard Zou <zou3519@gmail.com>
parent 447ce222
......@@ -28,6 +28,7 @@ from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _encode_layer_name
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
......@@ -170,7 +171,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
k, v, self.layer_name
k, v, _encode_layer_name(self.layer_name)
)
return q, k, v, kv_cache_dummy_dep
......
......@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8
......@@ -53,21 +54,43 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property
def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
# When _USE_LAYERNAME is enabled (torch >= 2.11), layer_name is
# passed as an explicit pattern input so the pattern matcher
# treats it as a wildcard matching hoisted LayerName placeholders.
# Otherwise it stays as a closure constant (original behavior).
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self._num_heads * self._head_size]
)
return self._quant_matcher(attn_out_view, scale)[0]
return _pattern_with_ln
def _pattern(q, k, v, output_attn, scale, kv_cache_dummy_dep):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -81,14 +104,34 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property
def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
):
output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(at1[1], [-1, self._num_heads * self._head_size])
return _replacement_with_ln
def _replacement(q, k, v, output_attn, scale, kv_cache_dummy_dep):
output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size],
dtype=FP8_DTYPE,
......@@ -100,7 +143,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
key=k,
value=v,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -113,7 +156,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
dtype = self._dtype
num_heads = self._num_heads
head_size = self._head_size
return [
inputs: list = [
self.empty(5, num_heads, head_size, dtype=dtype), # q
self.empty(5, num_heads, head_size, dtype=dtype), # k
self.empty(5, num_heads, head_size, dtype=dtype), # v
......@@ -121,6 +164,9 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
self.empty_fp32(1, 1), # scale
self.empty(0, dtype=dtype), # kv_cache_dummy_dep
]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class AttnNvfp4QuantPattern(
......@@ -144,23 +190,64 @@ class AttnNvfp4QuantPattern(
@property
def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q,
k,
v,
output_attn,
output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self._num_heads * self._head_size]
)
at2 = auto_functionalized(
self._QUANT_OP,
input=attn_out_view,
input_scale=input_scale,
is_sf_swizzled_layout=True,
output=output_quant,
output_scale=output_scale,
)
return at2[1], torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return _pattern_with_ln
def _pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
q,
k,
v,
output_attn,
output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -176,42 +263,80 @@ class AttnNvfp4QuantPattern(
output=output_quant,
output_scale=output_scale,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
return at2[1], torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return _pattern
@property
def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q,
k,
v,
output_attn,
_output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size // 2],
dtype=FP4_DTYPE,
device=q.device,
)
osv = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=input_scale,
output_block_scale=osv,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(
at2[1], [-1, self._num_heads * self._head_size // 2]
), at2[2]
return _replacement_with_ln
def _replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
_output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
q,
k,
v,
output_attn,
_output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
):
output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size // 2],
dtype=FP4_DTYPE,
device=q.device,
)
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
osv = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=input_scale,
output_block_scale=output_scale_view,
output_block_scale=osv,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output = RESHAPE_OP(at2[1], [-1, self._num_heads * self._head_size // 2])
return output, at2[2]
return RESHAPE_OP(
at2[1], [-1, self._num_heads * self._head_size // 2]
), at2[2]
return _replacement
......@@ -219,18 +344,19 @@ class AttnNvfp4QuantPattern(
dtype = self._dtype
num_heads = self._num_heads
head_size = self._head_size
return [
inputs: list = [
self.empty_bf16(5, num_heads, head_size), # q
self.empty_bf16(5, num_heads, head_size), # k
self.empty_bf16(5, num_heads, head_size), # v
self.empty_bf16(5, num_heads, head_size), # output_attn
self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE), # output_quant
self.empty_i32(
128, round_up(num_heads * head_size // 16, 4)
), # output_scale
self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE),
self.empty_i32(128, round_up(num_heads * head_size // 16, 4)),
self.empty_fp32(1, 1), # input_scale
self.empty(0, dtype=dtype), # kv_cache_dummy_dep
]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
......@@ -259,13 +385,19 @@ class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
"so no fusion patterns were registered."
)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for layer in layers:
if layer.impl.fused_output_quant_supported(_FP8_QUANT_KEY):
self.register(AttnFp8StaticQuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
for layer in layers:
if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
self.register(AttnNvfp4QuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.pm_pass)
......@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8
......@@ -49,21 +50,43 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property
def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
scale,
kv_cache_dummy_dep,
layer_name,
):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
# MLA output is already 2D (T, N*V), no reshape needed
return self._quant_matcher(at1[1], scale)[0]
return _pattern_with_ln
def _pattern(q, kv_c_normed, k_pe, output_attn, scale, kv_cache_dummy_dep):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -75,14 +98,41 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property
def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
scale,
kv_cache_dummy_dep,
layer_name,
):
# MLA output in quant_dtype
output_attn = torch.empty(
[q.shape[0], self._output_dim],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return at1[1]
return _replacement_with_ln
def _replacement(q, kv_c_normed, k_pe, output_attn, scale, kv_cache_dummy_dep):
# MLA output in quant_dtype
output_attn = torch.empty(
[q.shape[0], self._output_dim],
......@@ -95,7 +145,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -105,7 +155,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
return _replacement
def get_inputs(self) -> list[torch.Tensor]:
return [
inputs: list = [
self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
......@@ -113,6 +163,9 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
self.empty_fp32(1, 1),
self.empty(0, dtype=self._dtype),
]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class MLAAttnNvfp4QuantPattern(
......@@ -141,21 +194,56 @@ class MLAAttnNvfp4QuantPattern(
def pattern(
self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output_quant, output_scale = create_fp4_output_tensors(
at1[1].shape[0], at1[1].shape[1], at1[1].device, True
)
at2 = auto_functionalized(
self._QUANT_OP,
input=at1[1],
input_scale=input_scale,
is_sf_swizzled_layout=True,
output=output_quant,
output_scale=output_scale,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
return _pattern_with_ln
def _pattern(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep
):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -182,14 +270,47 @@ class MLAAttnNvfp4QuantPattern(
def replacement(
self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
# MLA output in quant_dtype (FP4 packed as uint8)
output_attn = torch.empty(
[q.shape[0], self._output_dim // 2],
dtype=FP4_DTYPE,
device=q.device,
)
output_scale = create_fp4_output_tensors(
q.shape[0], self._output_dim, q.device, True
)[1]
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return at2[1], at2[2]
return _replacement_with_ln
def _replacement(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_attn: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep
):
# MLA output in quant_dtype (FP4 packed as uint8)
output_attn = torch.empty(
[q.shape[0], self._output_dim // 2],
......@@ -207,7 +328,7 @@ class MLAAttnNvfp4QuantPattern(
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=self._layer_name,
layer_name=_ln,
output_scale=input_scale,
output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep,
......@@ -217,7 +338,7 @@ class MLAAttnNvfp4QuantPattern(
return _replacement
def get_inputs(self) -> list[torch.Tensor]:
return [
inputs: list = [
self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
......@@ -225,6 +346,9 @@ class MLAAttnNvfp4QuantPattern(
self.empty_fp32(1, 1),
self.empty(0, dtype=self._dtype),
]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
......@@ -250,13 +374,19 @@ class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
"so no fusion patterns were registered."
)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for layer in layers:
if layer.impl.fused_output_quant_supported(kFp8StaticTensorSym):
self.register(MLAAttnFp8StaticQuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
for layer in layers:
if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
self.register(MLAAttnNvfp4QuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.pm_pass)
......@@ -15,7 +15,13 @@ from vllm.model_executor.layers.attention.attention import (
Attention,
get_attention_context,
)
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
_USE_LAYERNAME,
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
......@@ -37,7 +43,7 @@ def fused_rope_and_unified_kv_cache_update_impl(
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
layer_name: LayerNameType,
) -> torch.Tensor:
"""
This impl fetches the KV cache and slot mapping from the forward context,
......@@ -46,6 +52,7 @@ def fused_rope_and_unified_kv_cache_update_impl(
that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
layer_name = _resolve_layer_name(layer_name)
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
attn_layer.impl.do_rope_and_kv_cache_update(
......@@ -70,7 +77,7 @@ def fused_rope_and_unified_kv_cache_update_fake(
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
layer_name: LayerNameType,
) -> torch.Tensor:
return torch.empty(0, device=query.device, dtype=query.dtype)
......@@ -120,38 +127,30 @@ class RopeReshapeKVCachePattern:
num_kv_heads=self.num_kv_heads,
)
def get_inputs(self) -> list[torch.Tensor]:
def get_inputs(self) -> list:
# Sample inputs to help pattern tracing
T = 5
L = 4096
qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
positions = empty_i64(T)
cos_sin_cache = empty_bf16(L, self.head_size)
return [
qkv,
positions,
cos_sin_cache,
]
inputs: list = [qkv, positions, cos_sin_cache]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self.layer_name))
return inputs
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def _mk_pattern_with_layer_name_input(self, _ln):
"""Pattern/replacement with layer_name as an explicit input."""
def pattern(qkv, positions, cos_sin_cache, layer_name):
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
return dummy, q, k, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return torch.ops.vllm.unified_kv_cache_update(k, v, layer_name), q, k, v
def replacement(qkv, positions, cos_sin_cache, layer_name):
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
......@@ -164,10 +163,50 @@ class RopeReshapeKVCachePattern:
positions=positions,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
layer_name=self.layer_name,
layer_name=layer_name,
)
return results[0], results[1], results[2], v
return pattern, replacement
# Builds the (pattern, replacement) pair whose callables take only
# (qkv, positions, cos_sin_cache). `_ln` is captured from this call and passed
# as the layer-name argument of each op, so it is not one of their inputs.
def _mk_pattern_with_layer_name_closure(self, _ln):
"""Pattern/replacement with layer_name as a closure constant."""
# Pattern side: split the packed qkv by (q_size, k_size, v_size), run the
# rope matcher on q/k, view q/k/v per head, then call the unfused
# unified_kv_cache_update op on k and v.
def pattern(qkv, positions, cos_sin_cache):
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
# Output order is (kv-cache-update result, q, k, v); replacement returns
# a 4-tuple in the same positions.
return torch.ops.vllm.unified_kv_cache_update(k, v, _ln), q, k, v
# Replacement side: same split and views, but no rope_matcher call here;
# positions and cos_sin_cache are handed to FUSED_OP instead.
def replacement(qkv, positions, cos_sin_cache):
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
results = auto_functionalized(
self.FUSED_OP,
query=q,
key=k,
value=v,
positions=positions,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
layer_name=_ln,
)
# NOTE(review): results[0..2] presumably correspond to the pattern's
# (dummy, q, k) outputs — confirm against FUSED_OP's schema. v is returned
# from the local view, not from `results`.
return results[0], results[1], results[2], v
return pattern, replacement
def register(self, pm_pass: PatternMatcherPass) -> None:
_ln = _encode_layer_name(self.layer_name)
if _USE_LAYERNAME:
pattern, replacement = self._mk_pattern_with_layer_name_input(_ln)
else:
pattern, replacement = self._mk_pattern_with_layer_name_closure(_ln)
# NOTE: use view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
......@@ -176,7 +215,11 @@ class RopeReshapeKVCachePattern:
return gm
pm.register_replacement(
pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
pattern,
replacement,
self.get_inputs(),
fwd_and_view_to_reshape,
pm_pass,
)
......@@ -205,6 +248,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass):
self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num
attn_layers = get_layers_from_vllm_config(config, Attention)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for _, layer in attn_layers.items():
if layer.impl.fused_rope_kvcache_supported():
for is_neox in [True, False]:
......@@ -212,6 +257,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass):
layer=layer,
is_neox=is_neox,
).register(self.patterns)
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.patterns)
......
......@@ -129,6 +129,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
VLLM_USE_LAYERNAME: bool = True
Q_SCALE_CONSTANT: int = 200
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
......@@ -1090,6 +1091,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")
),
"VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache,
# If set to "0", disable LayerName opaque type for layer_name
# parameters in custom ops. Defaults to enabled on torch >= 2.11.
"VLLM_USE_LAYERNAME": lambda: bool(int(os.getenv("VLLM_USE_LAYERNAME", "1"))),
# If set, vllm will run in development mode, which will enable
# some additional endpoints for developing and debugging,
# e.g. `/reset_prefix_cache`
......
......@@ -25,6 +25,9 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
......@@ -414,7 +417,9 @@ class Attention(nn.Module, AttentionLayerBase):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
torch.ops.vllm.maybe_calc_kv_scales(
query, key, value, _encode_layer_name(self.layer_name)
)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
......@@ -466,6 +471,7 @@ class Attention(nn.Module, AttentionLayerBase):
)
else:
# Skip this if sharing KV cache with an earlier attention layer.
encoded = _encode_layer_name(self.layer_name)
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
......@@ -473,14 +479,14 @@ class Attention(nn.Module, AttentionLayerBase):
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
key, value, encoded
)
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
encoded,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output.view(-1, hidden_size)
......@@ -553,8 +559,9 @@ def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
......@@ -570,7 +577,7 @@ def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
return
......@@ -622,12 +629,13 @@ def get_attention_context(
def unified_kv_cache_update(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> torch.Tensor:
"""
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
layer_name = _resolve_layer_name(layer_name)
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
......@@ -647,7 +655,7 @@ def unified_kv_cache_update(
def unified_kv_cache_update_fake(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> torch.Tensor:
return torch.empty(0, device=key.device, dtype=key.dtype)
......@@ -666,7 +674,7 @@ def unified_attention_with_output(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
......@@ -675,6 +683,7 @@ def unified_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
layer_name = _resolve_layer_name(layer_name)
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
self.impl.forward(
......@@ -695,7 +704,7 @@ def unified_attention_with_output_fake(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
......
......@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
from vllm.utils.torch_utils import _resolve_layer_name
def maybe_transfer_kv_layer(func: Callable) -> Callable:
......@@ -38,7 +39,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs)
layer_name: str = args[layer_name_index]
layer_name = _resolve_layer_name(args[layer_name_index])
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
......
......@@ -240,6 +240,9 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype,
......@@ -473,7 +476,12 @@ class MLAAttention(nn.Module, AttentionLayerBase):
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
torch.ops.vllm.maybe_calc_kv_scales(
q,
kv_c_normed,
k_pe,
_encode_layer_name(self.layer_name),
)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
......@@ -505,10 +513,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
return output
else:
encoded = _encode_layer_name(self.layer_name)
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
kv_c_normed,
k_pe,
self.layer_name,
encoded,
self.kv_cache_dtype,
self._k_scale,
)
......@@ -518,7 +527,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_c_normed,
k_pe,
output,
self.layer_name,
encoded,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
......@@ -900,7 +909,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
def unified_mla_kv_cache_update(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> torch.Tensor:
......@@ -908,6 +917,7 @@ def unified_mla_kv_cache_update(
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
layer_name = _resolve_layer_name(layer_name)
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache
......@@ -939,7 +949,7 @@ def unified_mla_kv_cache_update(
def unified_mla_kv_cache_update_fake(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> torch.Tensor:
......@@ -959,7 +969,7 @@ def unified_mla_attention_with_output(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
......@@ -968,6 +978,7 @@ def unified_mla_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
layer_name = _resolve_layer_name(layer_name)
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
layer.forward_impl(
q,
......@@ -986,7 +997,7 @@ def unified_mla_attention_with_output_fake(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
......
......@@ -10,7 +10,12 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
......@@ -170,7 +175,9 @@ class StaticSinkAttention(Attention, CustomOp):
)
if not self.sink_populated:
self_kv_cache = self.kv_cache
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
torch.ops.vllm.maybe_populate_sink(
self_kv_cache, _encode_layer_name(self.layer_name)
)
return super().forward(query, key, value, output_shape)
......@@ -224,8 +231,9 @@ class StaticSinkAttention(Attention, CustomOp):
def maybe_populate_sink(
self_kv_cache: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if self.sink_populated or self_kv_cache.numel() == 0:
......@@ -235,7 +243,7 @@ def maybe_populate_sink(
def maybe_populate_sink_fake(
self_kv_cache: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
return
......
......@@ -32,15 +32,15 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
HAS_OPAQUE_TYPE,
ModuleName,
_USE_LAYERNAME,
LayerName,
direct_register_custom_op,
)
def get_layer_from_name(layer_name: str) -> torch.nn.Module:
forward_context: ForwardContext = get_forward_context()
if not HAS_OPAQUE_TYPE and layer_name == "from_forward_context":
if not _USE_LAYERNAME and layer_name == "from_forward_context":
all_moe_layers = forward_context.all_moe_layers
assert all_moe_layers is not None
moe_layer_index = forward_context.moe_layer_index
......@@ -55,21 +55,21 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return forward_context.no_compile_layers[layer_name]
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# On torch >= 2.11, layer_name is a hoisted LayerName opaque object;
# on older versions it remains a plain str.
if TYPE_CHECKING:
from typing import TypeAlias
_layer_name_type: TypeAlias = str | ModuleName
_layer_name_type: TypeAlias = str | LayerName
else:
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
_layer_name_type = LayerName if _USE_LAYERNAME else str
@torch.compiler.assume_constant_result
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
def _resolve_layer_name(layer_name: str | LayerName) -> str:
from torch._library.fake_class_registry import FakeScriptObject
if isinstance(layer_name, ModuleName):
if isinstance(layer_name, LayerName):
return layer_name.value
elif isinstance(layer_name, FakeScriptObject):
return layer_name.real_obj.value
......@@ -331,9 +331,9 @@ class MoERunnerBase(MoERunner):
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
def _encode_layer_name(self) -> str | ModuleName:
if HAS_OPAQUE_TYPE:
return ModuleName(self.layer_name)
def _encode_layer_name(self) -> str | LayerName:
if _USE_LAYERNAME:
return LayerName(self.layer_name)
# Can be unavailable or None in unittests
if (
is_forward_context_available()
......
......@@ -56,7 +56,12 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
......@@ -568,7 +573,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
b,
a,
core_attn_out,
self.prefix,
_encode_layer_name(self.prefix),
)
# ============================================================
......@@ -1084,13 +1089,14 @@ def gdn_attention_core(
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
"""
Custom op for the core attention computation.
Only handles the convolution + recurrent attention part.
Input/output projections are handled outside this op.
"""
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward_core(
......@@ -1106,7 +1112,7 @@ def gdn_attention_core_fake(
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
"""Fake implementation for torch.compile."""
return
......
......@@ -36,7 +36,12 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
......@@ -228,7 +233,7 @@ class MambaMixer(MambaBase, PluggableLayer):
torch.ops.vllm.mamba_mixer(
hidden_states,
output,
self.prefix,
_encode_layer_name(self.prefix),
)
def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
......@@ -515,8 +520,9 @@ def split_batch_to_prefill_and_decode(
def mamba_mixer(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_impl(hidden_states=hidden_states, output=output)
......@@ -525,7 +531,7 @@ def mamba_mixer(
def mamba_mixer_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
return
......
......@@ -44,7 +44,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
......@@ -536,7 +541,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
torch.ops.vllm.mamba_mixer2(
projected_states,
ssm_output,
self.prefix,
_encode_layer_name(self.prefix),
)
# 4. gated MLP
......@@ -927,8 +932,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
def mamba_mixer2(
projected_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.conv_ssm_forward(projected_states=projected_states, output=output)
......@@ -937,7 +943,7 @@ def mamba_mixer2(
def mamba_mixer2_fake(
projected_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
layer_name: LayerNameType,
) -> None:
return
......
......@@ -11,7 +11,12 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
......@@ -30,7 +35,7 @@ RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
......@@ -46,6 +51,7 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
......@@ -253,7 +259,7 @@ def sparse_attn_indexer(
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
......@@ -342,7 +348,7 @@ class SparseAttnIndexer(CustomOp):
):
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
_encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache,
q_fp8,
k,
......@@ -366,7 +372,7 @@ class SparseAttnIndexer(CustomOp):
if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
_encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache,
q_fp8,
k,
......
......@@ -15,6 +15,7 @@ from packaging import version
from packaging.version import Version
from torch.library import Library, infer_schema
import vllm.envs as envs
from vllm.logger import init_logger
if TYPE_CHECKING:
......@@ -709,37 +710,62 @@ def is_torch_equal(target: str) -> bool:
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")
# Allow toggling LayerName usage via environment variable.
# Defaults to True on torch >= 2.11, False otherwise.
# Set VLLM_USE_LAYERNAME=0 to disable even on torch >= 2.11.
_USE_LAYERNAME = HAS_OPAQUE_TYPE and envs.VLLM_USE_LAYERNAME
if HAS_OPAQUE_TYPE:
from torch._opaque_base import OpaqueBase
else:
OpaqueBase = object # type: ignore[misc, assignment]
class ModuleName(OpaqueBase): # type: ignore[misc]
class LayerName(OpaqueBase): # type: ignore[misc]
"""Wraps a module name string for use as a torch opaque type.
When torch >= 2.11, this is registered as a hoisted value-type opaque
object so that torch.compile lifts it as a graph input instead of baking
it as a constant. This avoids per-layer recompilation for MOE ops.
it as a constant. This avoids per-layer recompilation for custom ops
that accept layer name strings (attention, MOE, KV cache, etc.).
"""
def __init__(self, value: str):
self.value = value
def __eq__(self, other):
return isinstance(other, ModuleName) and self.value == other.value
return isinstance(other, LayerName) and self.value == other.value
def __hash__(self):
return hash(self.value)
def __fx_repr__(self):
return (f"ModuleName({self.value!r})", {ModuleName})
return (f"LayerName({self.value!r})", {"LayerName": LayerName})
if HAS_OPAQUE_TYPE:
from torch._library.opaque_object import register_opaque_type
register_opaque_type(ModuleName, typ="value", hoist=True)
register_opaque_type(LayerName, typ="value", hoist=True)
# On torch >= 2.11 (with VLLM_USE_LAYERNAME enabled), custom op
# layer_name parameters use LayerName; otherwise they remain plain str.
if TYPE_CHECKING:
from typing import TypeAlias
LayerNameType: TypeAlias = str | LayerName
else:
LayerNameType = LayerName if _USE_LAYERNAME else str
def _resolve_layer_name(layer_name: str | LayerName) -> str:
    """Return the plain string form of a layer name.

    A ``LayerName`` wrapper is unwrapped to its ``value``; any other
    argument (normally a ``str``) is handed back as-is.
    """
    if isinstance(layer_name, LayerName):
        return layer_name.value
    return layer_name
def _encode_layer_name(layer_name: str) -> str | LayerName:
    """Prepare a layer name for passing to a custom op.

    Wraps the string in ``LayerName`` when ``_USE_LAYERNAME`` is truthy;
    otherwise the string is returned untouched.
    """
    if not _USE_LAYERNAME:
        return layer_name
    return LayerName(layer_name)
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
......
......@@ -9,6 +9,7 @@ import torch
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import LayerNameType
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
......@@ -459,7 +460,7 @@ def rocm_fp8_mqa_logits(
def rocm_aiter_sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
......@@ -486,7 +487,7 @@ def rocm_aiter_sparse_attn_indexer_fake(
def rocm_aiter_sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
......@@ -502,6 +503,9 @@ def rocm_aiter_sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
from vllm.utils.torch_utils import _resolve_layer_name
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return rocm_aiter_sparse_attn_indexer_fake(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment