Unverified Commit f44afef6 authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[compile] Allow strings in custom ops without regressing compilation times (#38123)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent 447ce222
...@@ -28,6 +28,7 @@ from vllm.forward_context import get_forward_context, set_forward_context ...@@ -28,6 +28,7 @@ from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import _encode_layer_name
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -170,7 +171,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module): ...@@ -170,7 +171,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
k = k.view(-1, self.num_kv_heads, self.head_size) k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size) v = v.view(-1, self.num_kv_heads, self.head_size)
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
k, v, self.layer_name k, v, _encode_layer_name(self.layer_name)
) )
return q, k, v, kv_cache_dummy_dep return q, k, v, kv_cache_dummy_dep
......
...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8 from .matcher_utils import MatcherQuantFP8
...@@ -53,21 +54,43 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -53,21 +54,43 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property @property
def pattern(self) -> Callable[..., torch.Tensor]: def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern( # When _USE_LAYERNAME is enabled (torch >= 2.11), layer_name is
q: torch.Tensor, # passed as an explicit pattern input so the pattern matcher
k: torch.Tensor, # treats it as a wildcard matching hoisted LayerName placeholders.
v: torch.Tensor, # Otherwise it stays as a closure constant (original behavior).
output_attn: torch.Tensor, _ln = _encode_layer_name(self._layer_name)
scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor, if _USE_LAYERNAME:
) -> torch.Tensor:
def _pattern_with_ln( # type: ignore[misc]
q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self._num_heads * self._head_size]
)
return self._quant_matcher(attn_out_view, scale)[0]
return _pattern_with_ln
def _pattern(q, k, v, output_attn, scale, kv_cache_dummy_dep):
at1 = auto_functionalized( at1 = auto_functionalized(
ATTN_OP, ATTN_OP,
query=q, query=q,
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=None, output_scale=None,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -81,14 +104,34 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -81,14 +104,34 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property @property
def replacement(self) -> Callable[..., torch.Tensor]: def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement( _ln = _encode_layer_name(self._layer_name)
q: torch.Tensor,
k: torch.Tensor, if _USE_LAYERNAME:
v: torch.Tensor,
output_attn: torch.Tensor, def _replacement_with_ln( # type: ignore[misc]
scale: torch.Tensor, q, k, v, output_attn, scale, kv_cache_dummy_dep, layer_name
kv_cache_dummy_dep: torch.Tensor, ):
) -> torch.Tensor: output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(at1[1], [-1, self._num_heads * self._head_size])
return _replacement_with_ln
def _replacement(q, k, v, output_attn, scale, kv_cache_dummy_dep):
output_attn = torch.empty( output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size], [q.shape[0], self._num_heads, self._head_size],
dtype=FP8_DTYPE, dtype=FP8_DTYPE,
...@@ -100,7 +143,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -100,7 +143,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=scale, output_scale=scale,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -113,7 +156,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -113,7 +156,7 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
dtype = self._dtype dtype = self._dtype
num_heads = self._num_heads num_heads = self._num_heads
head_size = self._head_size head_size = self._head_size
return [ inputs: list = [
self.empty(5, num_heads, head_size, dtype=dtype), # q self.empty(5, num_heads, head_size, dtype=dtype), # q
self.empty(5, num_heads, head_size, dtype=dtype), # k self.empty(5, num_heads, head_size, dtype=dtype), # k
self.empty(5, num_heads, head_size, dtype=dtype), # v self.empty(5, num_heads, head_size, dtype=dtype), # v
...@@ -121,6 +164,9 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -121,6 +164,9 @@ class AttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
self.empty_fp32(1, 1), # scale self.empty_fp32(1, 1), # scale
self.empty(0, dtype=dtype), # kv_cache_dummy_dep self.empty(0, dtype=dtype), # kv_cache_dummy_dep
] ]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class AttnNvfp4QuantPattern( class AttnNvfp4QuantPattern(
...@@ -144,23 +190,64 @@ class AttnNvfp4QuantPattern( ...@@ -144,23 +190,64 @@ class AttnNvfp4QuantPattern(
@property @property
def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: def pattern(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q,
k,
v,
output_attn,
output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self._num_heads * self._head_size]
)
at2 = auto_functionalized(
self._QUANT_OP,
input=attn_out_view,
input_scale=input_scale,
is_sf_swizzled_layout=True,
output=output_quant,
output_scale=output_scale,
)
return at2[1], torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return _pattern_with_ln
def _pattern( def _pattern(
q: torch.Tensor, q,
k: torch.Tensor, k,
v: torch.Tensor, v,
output_attn: torch.Tensor, output_attn,
output_quant: torch.Tensor, output_quant,
output_scale: torch.Tensor, output_scale,
input_scale: torch.Tensor, input_scale,
kv_cache_dummy_dep: torch.Tensor, kv_cache_dummy_dep,
) -> tuple[torch.Tensor, torch.Tensor]: ):
at1 = auto_functionalized( at1 = auto_functionalized(
ATTN_OP, ATTN_OP,
query=q, query=q,
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=None, output_scale=None,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -176,42 +263,80 @@ class AttnNvfp4QuantPattern( ...@@ -176,42 +263,80 @@ class AttnNvfp4QuantPattern(
output=output_quant, output=output_quant,
output_scale=output_scale, output_scale=output_scale,
) )
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
return _pattern return _pattern
@property @property
def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: def replacement(self) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q,
k,
v,
output_attn,
_output_quant,
output_scale,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size // 2],
dtype=FP4_DTYPE,
device=q.device,
)
osv = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=layer_name,
output_scale=input_scale,
output_block_scale=osv,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return RESHAPE_OP(
at2[1], [-1, self._num_heads * self._head_size // 2]
), at2[2]
return _replacement_with_ln
def _replacement( def _replacement(
q: torch.Tensor, q,
k: torch.Tensor, k,
v: torch.Tensor, v,
output_attn: torch.Tensor, output_attn,
_output_quant: torch.Tensor, _output_quant,
output_scale: torch.Tensor, output_scale,
input_scale: torch.Tensor, input_scale,
kv_cache_dummy_dep: torch.Tensor, kv_cache_dummy_dep,
) -> tuple[torch.Tensor, torch.Tensor]: ):
output_attn = torch.empty( output_attn = torch.empty(
[q.shape[0], self._num_heads, self._head_size // 2], [q.shape[0], self._num_heads, self._head_size // 2],
dtype=FP4_DTYPE, dtype=FP4_DTYPE,
device=q.device, device=q.device,
) )
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) osv = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized( at2 = auto_functionalized(
ATTN_OP, ATTN_OP,
query=q, query=q,
key=k, key=k,
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=input_scale, output_scale=input_scale,
output_block_scale=output_scale_view, output_block_scale=osv,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
output = RESHAPE_OP(at2[1], [-1, self._num_heads * self._head_size // 2]) return RESHAPE_OP(
return output, at2[2] at2[1], [-1, self._num_heads * self._head_size // 2]
), at2[2]
return _replacement return _replacement
...@@ -219,18 +344,19 @@ class AttnNvfp4QuantPattern( ...@@ -219,18 +344,19 @@ class AttnNvfp4QuantPattern(
dtype = self._dtype dtype = self._dtype
num_heads = self._num_heads num_heads = self._num_heads
head_size = self._head_size head_size = self._head_size
return [ inputs: list = [
self.empty_bf16(5, num_heads, head_size), # q self.empty_bf16(5, num_heads, head_size), # q
self.empty_bf16(5, num_heads, head_size), # k self.empty_bf16(5, num_heads, head_size), # k
self.empty_bf16(5, num_heads, head_size), # v self.empty_bf16(5, num_heads, head_size), # v
self.empty_bf16(5, num_heads, head_size), # output_attn self.empty_bf16(5, num_heads, head_size), # output_attn
self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE), # output_quant self.empty(5, num_heads * head_size // 2, dtype=FP4_DTYPE),
self.empty_i32( self.empty_i32(128, round_up(num_heads * head_size // 16, 4)),
128, round_up(num_heads * head_size // 16, 4)
), # output_scale
self.empty_fp32(1, 1), # input_scale self.empty_fp32(1, 1), # input_scale
self.empty(0, dtype=dtype), # kv_cache_dummy_dep self.empty(0, dtype=dtype), # kv_cache_dummy_dep
] ]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class AttnQuantFusionPass(VllmFusionPatternMatcherPass): class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
...@@ -259,13 +385,19 @@ class AttnQuantFusionPass(VllmFusionPatternMatcherPass): ...@@ -259,13 +385,19 @@ class AttnQuantFusionPass(VllmFusionPatternMatcherPass):
"so no fusion patterns were registered." "so no fusion patterns were registered."
) )
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for layer in layers: for layer in layers:
if layer.impl.fused_output_quant_supported(_FP8_QUANT_KEY): if layer.impl.fused_output_quant_supported(_FP8_QUANT_KEY):
self.register(AttnFp8StaticQuantPattern(layer, dtype)) self.register(AttnFp8StaticQuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
for layer in layers: for layer in layers:
if layer.impl.fused_output_quant_supported(kNvfp4Dynamic): if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
self.register(AttnNvfp4QuantPattern(layer, dtype)) self.register(AttnNvfp4QuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.pm_pass) self.dump_patterns(config, self.pm_pass)
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Dynamic, kNvfp4Dynamic,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import _USE_LAYERNAME, _encode_layer_name
from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement from ..vllm_inductor_pass import VllmFusionPatternMatcherPass, VllmPatternReplacement
from .matcher_utils import MatcherQuantFP8 from .matcher_utils import MatcherQuantFP8
...@@ -49,21 +50,43 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -49,21 +50,43 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property @property
def pattern(self) -> Callable[..., torch.Tensor]: def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern( _ln = _encode_layer_name(self._layer_name)
q: torch.Tensor,
kv_c_normed: torch.Tensor, if _USE_LAYERNAME:
k_pe: torch.Tensor,
output_attn: torch.Tensor, def _pattern_with_ln( # type: ignore[misc]
scale: torch.Tensor, q,
kv_cache_dummy_dep: torch.Tensor, kv_c_normed,
) -> torch.Tensor: k_pe,
output_attn,
scale,
kv_cache_dummy_dep,
layer_name,
):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
# MLA output is already 2D (T, N*V), no reshape needed
return self._quant_matcher(at1[1], scale)[0]
return _pattern_with_ln
def _pattern(q, kv_c_normed, k_pe, output_attn, scale, kv_cache_dummy_dep):
at1 = auto_functionalized( at1 = auto_functionalized(
MLA_ATTN_OP, MLA_ATTN_OP,
q=q, q=q,
kv_c_normed=kv_c_normed, kv_c_normed=kv_c_normed,
k_pe=k_pe, k_pe=k_pe,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=None, output_scale=None,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -75,14 +98,41 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -75,14 +98,41 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
@property @property
def replacement(self) -> Callable[..., torch.Tensor]: def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement( _ln = _encode_layer_name(self._layer_name)
q: torch.Tensor,
kv_c_normed: torch.Tensor, if _USE_LAYERNAME:
k_pe: torch.Tensor,
output_attn: torch.Tensor, def _replacement_with_ln( # type: ignore[misc]
scale: torch.Tensor, q,
kv_cache_dummy_dep: torch.Tensor, kv_c_normed,
) -> torch.Tensor: k_pe,
output_attn,
scale,
kv_cache_dummy_dep,
layer_name,
):
# MLA output in quant_dtype
output_attn = torch.empty(
[q.shape[0], self._output_dim],
dtype=FP8_DTYPE,
device=q.device,
)
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=scale,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return at1[1]
return _replacement_with_ln
def _replacement(q, kv_c_normed, k_pe, output_attn, scale, kv_cache_dummy_dep):
# MLA output in quant_dtype # MLA output in quant_dtype
output_attn = torch.empty( output_attn = torch.empty(
[q.shape[0], self._output_dim], [q.shape[0], self._output_dim],
...@@ -95,7 +145,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -95,7 +145,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
kv_c_normed=kv_c_normed, kv_c_normed=kv_c_normed,
k_pe=k_pe, k_pe=k_pe,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=scale, output_scale=scale,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -105,7 +155,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -105,7 +155,7 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
return _replacement return _replacement
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
return [ inputs: list = [
self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype), self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
self.empty(5, self._kv_lora_rank, dtype=self._dtype), self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype), self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
...@@ -113,6 +163,9 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]): ...@@ -113,6 +163,9 @@ class MLAAttnFp8StaticQuantPattern(VllmPatternReplacement[..., torch.Tensor]):
self.empty_fp32(1, 1), self.empty_fp32(1, 1),
self.empty(0, dtype=self._dtype), self.empty(0, dtype=self._dtype),
] ]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class MLAAttnNvfp4QuantPattern( class MLAAttnNvfp4QuantPattern(
...@@ -141,21 +194,56 @@ class MLAAttnNvfp4QuantPattern( ...@@ -141,21 +194,56 @@ class MLAAttnNvfp4QuantPattern(
def pattern( def pattern(
self, self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _pattern_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
at1 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=None,
output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output_quant, output_scale = create_fp4_output_tensors(
at1[1].shape[0], at1[1].shape[1], at1[1].device, True
)
at2 = auto_functionalized(
self._QUANT_OP,
input=at1[1],
input_scale=input_scale,
is_sf_swizzled_layout=True,
output=output_quant,
output_scale=output_scale,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
return _pattern_with_ln
def _pattern( def _pattern(
q: torch.Tensor, q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep
kv_c_normed: torch.Tensor, ):
k_pe: torch.Tensor,
output_attn: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = auto_functionalized( at1 = auto_functionalized(
MLA_ATTN_OP, MLA_ATTN_OP,
q=q, q=q,
kv_c_normed=kv_c_normed, kv_c_normed=kv_c_normed,
k_pe=k_pe, k_pe=k_pe,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=None, output_scale=None,
output_block_scale=None, output_block_scale=None,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -182,14 +270,47 @@ class MLAAttnNvfp4QuantPattern( ...@@ -182,14 +270,47 @@ class MLAAttnNvfp4QuantPattern(
def replacement( def replacement(
self, self,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
_ln = _encode_layer_name(self._layer_name)
if _USE_LAYERNAME:
def _replacement_with_ln( # type: ignore[misc]
q,
kv_c_normed,
k_pe,
output_attn,
input_scale,
kv_cache_dummy_dep,
layer_name,
):
# MLA output in quant_dtype (FP4 packed as uint8)
output_attn = torch.empty(
[q.shape[0], self._output_dim // 2],
dtype=FP4_DTYPE,
device=q.device,
)
output_scale = create_fp4_output_tensors(
q.shape[0], self._output_dim, q.device, True
)[1]
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
MLA_ATTN_OP,
q=q,
kv_c_normed=kv_c_normed,
k_pe=k_pe,
output=output_attn,
layer_name=layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return at2[1], at2[2]
return _replacement_with_ln
def _replacement( def _replacement(
q: torch.Tensor, q, kv_c_normed, k_pe, output_attn, input_scale, kv_cache_dummy_dep
kv_c_normed: torch.Tensor, ):
k_pe: torch.Tensor,
output_attn: torch.Tensor,
input_scale: torch.Tensor,
kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# MLA output in quant_dtype (FP4 packed as uint8) # MLA output in quant_dtype (FP4 packed as uint8)
output_attn = torch.empty( output_attn = torch.empty(
[q.shape[0], self._output_dim // 2], [q.shape[0], self._output_dim // 2],
...@@ -207,7 +328,7 @@ class MLAAttnNvfp4QuantPattern( ...@@ -207,7 +328,7 @@ class MLAAttnNvfp4QuantPattern(
kv_c_normed=kv_c_normed, kv_c_normed=kv_c_normed,
k_pe=k_pe, k_pe=k_pe,
output=output_attn, output=output_attn,
layer_name=self._layer_name, layer_name=_ln,
output_scale=input_scale, output_scale=input_scale,
output_block_scale=output_scale_view, output_block_scale=output_scale_view,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
...@@ -217,7 +338,7 @@ class MLAAttnNvfp4QuantPattern( ...@@ -217,7 +338,7 @@ class MLAAttnNvfp4QuantPattern(
return _replacement return _replacement
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
return [ inputs: list = [
self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype), self.empty(5, self._num_heads, self._qk_head_dim, dtype=self._dtype),
self.empty(5, self._kv_lora_rank, dtype=self._dtype), self.empty(5, self._kv_lora_rank, dtype=self._dtype),
self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype), self.empty(5, 1, self._qk_rope_head_dim, dtype=self._dtype),
...@@ -225,6 +346,9 @@ class MLAAttnNvfp4QuantPattern( ...@@ -225,6 +346,9 @@ class MLAAttnNvfp4QuantPattern(
self.empty_fp32(1, 1), self.empty_fp32(1, 1),
self.empty(0, dtype=self._dtype), self.empty(0, dtype=self._dtype),
] ]
if _USE_LAYERNAME:
inputs.append(_encode_layer_name(self._layer_name))
return inputs
class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass): class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
...@@ -250,13 +374,19 @@ class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass): ...@@ -250,13 +374,19 @@ class MLAAttnQuantFusionPass(VllmFusionPatternMatcherPass):
"so no fusion patterns were registered." "so no fusion patterns were registered."
) )
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for layer in layers: for layer in layers:
if layer.impl.fused_output_quant_supported(kFp8StaticTensorSym): if layer.impl.fused_output_quant_supported(kFp8StaticTensorSym):
self.register(MLAAttnFp8StaticQuantPattern(layer, dtype)) self.register(MLAAttnFp8StaticQuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
for layer in layers: for layer in layers:
if layer.impl.fused_output_quant_supported(kNvfp4Dynamic): if layer.impl.fused_output_quant_supported(kNvfp4Dynamic):
self.register(MLAAttnNvfp4QuantPattern(layer, dtype)) self.register(MLAAttnNvfp4QuantPattern(layer, dtype))
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.pm_pass) self.dump_patterns(config, self.pm_pass)
...@@ -15,7 +15,13 @@ from vllm.model_executor.layers.attention.attention import ( ...@@ -15,7 +15,13 @@ from vllm.model_executor.layers.attention.attention import (
Attention, Attention,
get_attention_context, get_attention_context,
) )
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import (
_USE_LAYERNAME,
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
...@@ -37,7 +43,7 @@ def fused_rope_and_unified_kv_cache_update_impl( ...@@ -37,7 +43,7 @@ def fused_rope_and_unified_kv_cache_update_impl(
positions: torch.Tensor, positions: torch.Tensor,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox: bool, is_neox: bool,
layer_name: str = "", layer_name: LayerNameType,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This impl fetches the KV cache and slot mapping from the forward context, This impl fetches the KV cache and slot mapping from the forward context,
...@@ -46,6 +52,7 @@ def fused_rope_and_unified_kv_cache_update_impl( ...@@ -46,6 +52,7 @@ def fused_rope_and_unified_kv_cache_update_impl(
that is passed to unified_attention to signal a side effect and that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering. the data dependency between them to ensure torch.compile preserves ordering.
""" """
layer_name = _resolve_layer_name(layer_name)
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None: if layer_slot_mapping is not None:
attn_layer.impl.do_rope_and_kv_cache_update( attn_layer.impl.do_rope_and_kv_cache_update(
...@@ -70,7 +77,7 @@ def fused_rope_and_unified_kv_cache_update_fake( ...@@ -70,7 +77,7 @@ def fused_rope_and_unified_kv_cache_update_fake(
positions: torch.Tensor, positions: torch.Tensor,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox: bool, is_neox: bool,
layer_name: str = "", layer_name: LayerNameType,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty(0, device=query.device, dtype=query.dtype) return torch.empty(0, device=query.device, dtype=query.dtype)
...@@ -120,38 +127,30 @@ class RopeReshapeKVCachePattern: ...@@ -120,38 +127,30 @@ class RopeReshapeKVCachePattern:
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
) )
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list:
# Sample inputs to help pattern tracing # Sample inputs to help pattern tracing
T = 5 T = 5
L = 4096 L = 4096
qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size) qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
positions = empty_i64(T) positions = empty_i64(T)
cos_sin_cache = empty_bf16(L, self.head_size) cos_sin_cache = empty_bf16(L, self.head_size)
return [ inputs: list = [qkv, positions, cos_sin_cache]
qkv, if _USE_LAYERNAME:
positions, inputs.append(_encode_layer_name(self.layer_name))
cos_sin_cache, return inputs
]
def register(self, pm_pass: PatternMatcherPass) -> None: def _mk_pattern_with_layer_name_input(self, _ln):
def pattern( """Pattern/replacement with layer_name as an explicit input."""
qkv: torch.Tensor,
positions: torch.Tensor, def pattern(qkv, positions, cos_sin_cache, layer_name):
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache) q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size) q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size) k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v) v = v.view(-1, self.num_kv_heads, self.head_size_v)
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name) return torch.ops.vllm.unified_kv_cache_update(k, v, layer_name), q, k, v
return dummy, q, k, v
def replacement(qkv, positions, cos_sin_cache, layer_name):
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size) q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size) k = k.view(-1, self.num_kv_heads, self.head_size)
...@@ -164,10 +163,50 @@ class RopeReshapeKVCachePattern: ...@@ -164,10 +163,50 @@ class RopeReshapeKVCachePattern:
positions=positions, positions=positions,
cos_sin_cache=cos_sin_cache, cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox, is_neox=self.is_neox,
layer_name=self.layer_name, layer_name=layer_name,
) )
return results[0], results[1], results[2], v return results[0], results[1], results[2], v
return pattern, replacement
def _mk_pattern_with_layer_name_closure(self, _ln):
"""Pattern/replacement with layer_name as a closure constant."""
def pattern(qkv, positions, cos_sin_cache):
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
return torch.ops.vllm.unified_kv_cache_update(k, v, _ln), q, k, v
def replacement(qkv, positions, cos_sin_cache):
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
results = auto_functionalized(
self.FUSED_OP,
query=q,
key=k,
value=v,
positions=positions,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
layer_name=_ln,
)
return results[0], results[1], results[2], v
return pattern, replacement
def register(self, pm_pass: PatternMatcherPass) -> None:
_ln = _encode_layer_name(self.layer_name)
if _USE_LAYERNAME:
pattern, replacement = self._mk_pattern_with_layer_name_input(_ln)
else:
pattern, replacement = self._mk_pattern_with_layer_name_closure(_ln)
# NOTE: use view_to_reshape to unify view/reshape to simplify # NOTE: use view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities # pattern and increase matching opportunities
def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule: def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
...@@ -176,7 +215,11 @@ class RopeReshapeKVCachePattern: ...@@ -176,7 +215,11 @@ class RopeReshapeKVCachePattern:
return gm return gm
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass pattern,
replacement,
self.get_inputs(),
fwd_and_view_to_reshape,
pm_pass,
) )
...@@ -205,6 +248,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass): ...@@ -205,6 +248,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass):
self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num
attn_layers = get_layers_from_vllm_config(config, Attention) attn_layers = get_layers_from_vllm_config(config, Attention)
# When _USE_LAYERNAME is enabled, layer_name is a wildcard so all
# layers produce the same pattern — register once then break.
for _, layer in attn_layers.items(): for _, layer in attn_layers.items():
if layer.impl.fused_rope_kvcache_supported(): if layer.impl.fused_rope_kvcache_supported():
for is_neox in [True, False]: for is_neox in [True, False]:
...@@ -212,6 +257,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass): ...@@ -212,6 +257,8 @@ class RopeKVCacheFusionPass(VllmPatternMatcherPass):
layer=layer, layer=layer,
is_neox=is_neox, is_neox=is_neox,
).register(self.patterns) ).register(self.patterns)
if _USE_LAYERNAME:
break
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
......
...@@ -129,6 +129,7 @@ if TYPE_CHECKING: ...@@ -129,6 +129,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
VLLM_USE_LAYERNAME: bool = True
Q_SCALE_CONSTANT: int = 200 Q_SCALE_CONSTANT: int = 200
K_SCALE_CONSTANT: int = 200 K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100 V_SCALE_CONSTANT: int = 100
...@@ -1090,6 +1091,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1090,6 +1091,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1") os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")
), ),
"VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache, "VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache,
# If set to "0", disable LayerName opaque type for layer_name
# parameters in custom ops. Defaults to enabled on torch >= 2.11.
"VLLM_USE_LAYERNAME": lambda: bool(int(os.getenv("VLLM_USE_LAYERNAME", "1"))),
# If set, vllm will run in development mode, which will enable # If set, vllm will run in development mode, which will enable
# some additional endpoints for developing and debugging, # some additional endpoints for developing and debugging,
# e.g. `/reset_prefix_cache` # e.g. `/reset_prefix_cache`
......
...@@ -25,6 +25,9 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -25,6 +25,9 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op, direct_register_custom_op,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
...@@ -414,7 +417,9 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -414,7 +417,9 @@ class Attention(nn.Module, AttentionLayerBase):
`vllm.forward_context.get_forward_context().attn_metadata`. `vllm.forward_context.get_forward_context().attn_metadata`.
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) torch.ops.vllm.maybe_calc_kv_scales(
query, key, value, _encode_layer_name(self.layer_name)
)
output_dtype = query.dtype output_dtype = query.dtype
if self.query_quant is not None: if self.query_quant is not None:
# quantizing with a simple torch operation enables # quantizing with a simple torch operation enables
...@@ -466,6 +471,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -466,6 +471,7 @@ class Attention(nn.Module, AttentionLayerBase):
) )
else: else:
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
encoded = _encode_layer_name(self.layer_name)
if ( if (
not self.attn_backend.forward_includes_kv_cache_update not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None and self.kv_sharing_target_layer_name is None
...@@ -473,14 +479,14 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -473,14 +479,14 @@ class Attention(nn.Module, AttentionLayerBase):
and value is not None and value is not None
): ):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name key, value, encoded
) )
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, query,
key, key,
value, value,
output, output,
self.layer_name, encoded,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
...@@ -553,8 +559,9 @@ def maybe_calc_kv_scales( ...@@ -553,8 +559,9 @@ def maybe_calc_kv_scales(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
...@@ -570,7 +577,7 @@ def maybe_calc_kv_scales_fake( ...@@ -570,7 +577,7 @@ def maybe_calc_kv_scales_fake(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
return return
...@@ -622,12 +629,13 @@ def get_attention_context( ...@@ -622,12 +629,13 @@ def get_attention_context(
def unified_kv_cache_update( def unified_kv_cache_update(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Returns a dummy that is passed to unified_attention to signal a side effect and Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering. the data dependency between them to ensure torch.compile preserves ordering.
""" """
layer_name = _resolve_layer_name(layer_name)
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None: if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
...@@ -647,7 +655,7 @@ def unified_kv_cache_update( ...@@ -647,7 +655,7 @@ def unified_kv_cache_update(
def unified_kv_cache_update_fake( def unified_kv_cache_update_fake(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty(0, device=key.device, dtype=key.dtype) return torch.empty(0, device=key.device, dtype=key.dtype)
...@@ -666,7 +674,7 @@ def unified_attention_with_output( ...@@ -666,7 +674,7 @@ def unified_attention_with_output(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None,
...@@ -675,6 +683,7 @@ def unified_attention_with_output( ...@@ -675,6 +683,7 @@ def unified_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and # that ensures torch.compile preserves ordering between KV cache update and
# attention forward. # attention forward.
del kv_cache_dummy_dep del kv_cache_dummy_dep
layer_name = _resolve_layer_name(layer_name)
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name) attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
self.impl.forward( self.impl.forward(
...@@ -695,7 +704,7 @@ def unified_attention_with_output_fake( ...@@ -695,7 +704,7 @@ def unified_attention_with_output_fake(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None,
......
...@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import ( ...@@ -9,6 +9,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group, is_v1_kv_transfer_group,
) )
from vllm.utils.torch_utils import _resolve_layer_name
def maybe_transfer_kv_layer(func: Callable) -> Callable: def maybe_transfer_kv_layer(func: Callable) -> Callable:
...@@ -38,7 +39,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable: ...@@ -38,7 +39,7 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs) return func(*args, **kwargs)
layer_name: str = args[layer_name_index] layer_name = _resolve_layer_name(args[layer_name_index])
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping) # Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
attn_metadata, _, kv_cache, _ = get_attention_context(layer_name) attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
......
...@@ -240,6 +240,9 @@ from vllm.platforms import current_platform ...@@ -240,6 +240,9 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op, direct_register_custom_op,
is_quantized_kv_cache, is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
...@@ -473,7 +476,12 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -473,7 +476,12 @@ class MLAAttention(nn.Module, AttentionLayerBase):
output_shape: torch.Size | None = None, output_shape: torch.Size | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.calculate_kv_scales: if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) torch.ops.vllm.maybe_calc_kv_scales(
q,
kv_c_normed,
k_pe,
_encode_layer_name(self.layer_name),
)
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
...@@ -505,10 +513,11 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -505,10 +513,11 @@ class MLAAttention(nn.Module, AttentionLayerBase):
) )
return output return output
else: else:
encoded = _encode_layer_name(self.layer_name)
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
kv_c_normed, kv_c_normed,
k_pe, k_pe,
self.layer_name, encoded,
self.kv_cache_dtype, self.kv_cache_dtype,
self._k_scale, self._k_scale,
) )
...@@ -518,7 +527,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -518,7 +527,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_c_normed, kv_c_normed,
k_pe, k_pe,
output, output,
self.layer_name, encoded,
kv_cache_dummy_dep=kv_cache_dummy_dep, kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
return output return output
...@@ -900,7 +909,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -900,7 +909,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
def unified_mla_kv_cache_update( def unified_mla_kv_cache_update(
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: torch.Tensor, k_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -908,6 +917,7 @@ def unified_mla_kv_cache_update( ...@@ -908,6 +917,7 @@ def unified_mla_kv_cache_update(
Returns a dummy that is passed to unified_attention to signal a side effect and Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering. the data dependency between them to ensure torch.compile preserves ordering.
""" """
layer_name = _resolve_layer_name(layer_name)
forward_context = get_forward_context() forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache kv_cache = attn_layer.kv_cache
...@@ -939,7 +949,7 @@ def unified_mla_kv_cache_update( ...@@ -939,7 +949,7 @@ def unified_mla_kv_cache_update(
def unified_mla_kv_cache_update_fake( def unified_mla_kv_cache_update_fake(
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: torch.Tensor, k_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -959,7 +969,7 @@ def unified_mla_attention_with_output( ...@@ -959,7 +969,7 @@ def unified_mla_attention_with_output(
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None,
...@@ -968,6 +978,7 @@ def unified_mla_attention_with_output( ...@@ -968,6 +978,7 @@ def unified_mla_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and # that ensures torch.compile preserves ordering between KV cache update and
# attention forward. # attention forward.
del kv_cache_dummy_dep del kv_cache_dummy_dep
layer_name = _resolve_layer_name(layer_name)
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
layer.forward_impl( layer.forward_impl(
q, q,
...@@ -986,7 +997,7 @@ def unified_mla_attention_with_output_fake( ...@@ -986,7 +997,7 @@ def unified_mla_attention_with_output_fake(
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None,
......
...@@ -10,7 +10,12 @@ from vllm.logger import init_logger ...@@ -10,7 +10,12 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
...@@ -170,7 +175,9 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -170,7 +175,9 @@ class StaticSinkAttention(Attention, CustomOp):
) )
if not self.sink_populated: if not self.sink_populated:
self_kv_cache = self.kv_cache self_kv_cache = self.kv_cache
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name) torch.ops.vllm.maybe_populate_sink(
self_kv_cache, _encode_layer_name(self.layer_name)
)
return super().forward(query, key, value, output_shape) return super().forward(query, key, value, output_shape)
...@@ -224,8 +231,9 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -224,8 +231,9 @@ class StaticSinkAttention(Attention, CustomOp):
def maybe_populate_sink( def maybe_populate_sink(
self_kv_cache: torch.Tensor, self_kv_cache: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
if self.sink_populated or self_kv_cache.numel() == 0: if self.sink_populated or self_kv_cache.numel() == 0:
...@@ -235,7 +243,7 @@ def maybe_populate_sink( ...@@ -235,7 +243,7 @@ def maybe_populate_sink(
def maybe_populate_sink_fake( def maybe_populate_sink_fake(
self_kv_cache: torch.Tensor, self_kv_cache: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
return return
......
...@@ -32,15 +32,15 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( ...@@ -32,15 +32,15 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
HAS_OPAQUE_TYPE, _USE_LAYERNAME,
ModuleName, LayerName,
direct_register_custom_op, direct_register_custom_op,
) )
def get_layer_from_name(layer_name: str) -> torch.nn.Module: def get_layer_from_name(layer_name: str) -> torch.nn.Module:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
if not HAS_OPAQUE_TYPE and layer_name == "from_forward_context": if not _USE_LAYERNAME and layer_name == "from_forward_context":
all_moe_layers = forward_context.all_moe_layers all_moe_layers = forward_context.all_moe_layers
assert all_moe_layers is not None assert all_moe_layers is not None
moe_layer_index = forward_context.moe_layer_index moe_layer_index = forward_context.moe_layer_index
...@@ -55,21 +55,21 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module: ...@@ -55,21 +55,21 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return forward_context.no_compile_layers[layer_name] return forward_context.no_compile_layers[layer_name]
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object; # On torch >= 2.11, layer_name is a hoisted LayerName opaque object;
# on older versions it remains a plain str. # on older versions it remains a plain str.
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeAlias from typing import TypeAlias
_layer_name_type: TypeAlias = str | ModuleName _layer_name_type: TypeAlias = str | LayerName
else: else:
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str _layer_name_type = LayerName if _USE_LAYERNAME else str
@torch.compiler.assume_constant_result @torch.compiler.assume_constant_result
def _resolve_layer_name(layer_name: str | ModuleName) -> str: def _resolve_layer_name(layer_name: str | LayerName) -> str:
from torch._library.fake_class_registry import FakeScriptObject from torch._library.fake_class_registry import FakeScriptObject
if isinstance(layer_name, ModuleName): if isinstance(layer_name, LayerName):
return layer_name.value return layer_name.value
elif isinstance(layer_name, FakeScriptObject): elif isinstance(layer_name, FakeScriptObject):
return layer_name.real_obj.value return layer_name.real_obj.value
...@@ -331,9 +331,9 @@ class MoERunnerBase(MoERunner): ...@@ -331,9 +331,9 @@ class MoERunnerBase(MoERunner):
assert len(trunc_sizes) == 1 assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0]) return func(states, trunc_sizes[0])
def _encode_layer_name(self) -> str | ModuleName: def _encode_layer_name(self) -> str | LayerName:
if HAS_OPAQUE_TYPE: if _USE_LAYERNAME:
return ModuleName(self.layer_name) return LayerName(self.layer_name)
# Can be unavailable or None in unittests # Can be unavailable or None in unittests
if ( if (
is_forward_context_available() is_forward_context_available()
......
...@@ -56,7 +56,12 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -56,7 +56,12 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
...@@ -568,7 +573,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -568,7 +573,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
b, b,
a, a,
core_attn_out, core_attn_out,
self.prefix, _encode_layer_name(self.prefix),
) )
# ============================================================ # ============================================================
...@@ -1084,13 +1089,14 @@ def gdn_attention_core( ...@@ -1084,13 +1089,14 @@ def gdn_attention_core(
b: torch.Tensor, b: torch.Tensor,
a: torch.Tensor, a: torch.Tensor,
core_attn_out: torch.Tensor, core_attn_out: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
""" """
Custom op for the core attention computation. Custom op for the core attention computation.
Only handles the convolution + recurrent attention part. Only handles the convolution + recurrent attention part.
Input/output projections are handled outside this op. Input/output projections are handled outside this op.
""" """
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self._forward_core( self._forward_core(
...@@ -1106,7 +1112,7 @@ def gdn_attention_core_fake( ...@@ -1106,7 +1112,7 @@ def gdn_attention_core_fake(
b: torch.Tensor, b: torch.Tensor,
a: torch.Tensor, a: torch.Tensor,
core_attn_out: torch.Tensor, core_attn_out: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
"""Fake implementation for torch.compile.""" """Fake implementation for torch.compile."""
return return
......
...@@ -36,7 +36,12 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( ...@@ -36,7 +36,12 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
...@@ -228,7 +233,7 @@ class MambaMixer(MambaBase, PluggableLayer): ...@@ -228,7 +233,7 @@ class MambaMixer(MambaBase, PluggableLayer):
torch.ops.vllm.mamba_mixer( torch.ops.vllm.mamba_mixer(
hidden_states, hidden_states,
output, output,
self.prefix, _encode_layer_name(self.prefix),
) )
def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
...@@ -515,8 +520,9 @@ def split_batch_to_prefill_and_decode( ...@@ -515,8 +520,9 @@ def split_batch_to_prefill_and_decode(
def mamba_mixer( def mamba_mixer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.forward_impl(hidden_states=hidden_states, output=output) self.forward_impl(hidden_states=hidden_states, output=output)
...@@ -525,7 +531,7 @@ def mamba_mixer( ...@@ -525,7 +531,7 @@ def mamba_mixer(
def mamba_mixer_fake( def mamba_mixer_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
return return
......
...@@ -44,7 +44,12 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -44,7 +44,12 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
...@@ -536,7 +541,7 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -536,7 +541,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
torch.ops.vllm.mamba_mixer2( torch.ops.vllm.mamba_mixer2(
projected_states, projected_states,
ssm_output, ssm_output,
self.prefix, _encode_layer_name(self.prefix),
) )
# 4. gated MLP # 4. gated MLP
...@@ -927,8 +932,9 @@ class MambaMixer2(MambaBase, PluggableLayer): ...@@ -927,8 +932,9 @@ class MambaMixer2(MambaBase, PluggableLayer):
def mamba_mixer2( def mamba_mixer2(
projected_states: torch.Tensor, projected_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
self.conv_ssm_forward(projected_states=projected_states, output=output) self.conv_ssm_forward(projected_states=projected_states, output=output)
...@@ -937,7 +943,7 @@ def mamba_mixer2( ...@@ -937,7 +943,7 @@ def mamba_mixer2(
def mamba_mixer2_fake( def mamba_mixer2_fake(
projected_states: torch.Tensor, projected_states: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: LayerNameType,
) -> None: ) -> None:
return return
......
...@@ -11,7 +11,12 @@ from vllm.logger import init_logger ...@@ -11,7 +11,12 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadata,
) )
...@@ -30,7 +35,7 @@ RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 ...@@ -30,7 +35,7 @@ RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
def sparse_attn_indexer( def sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -46,6 +51,7 @@ def sparse_attn_indexer( ...@@ -46,6 +51,7 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
# assert isinstance(attn_metadata, dict) # assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
...@@ -253,7 +259,7 @@ def sparse_attn_indexer( ...@@ -253,7 +259,7 @@ def sparse_attn_indexer(
def sparse_attn_indexer_fake( def sparse_attn_indexer_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -342,7 +348,7 @@ class SparseAttnIndexer(CustomOp): ...@@ -342,7 +348,7 @@ class SparseAttnIndexer(CustomOp):
): ):
return torch.ops.vllm.sparse_attn_indexer( return torch.ops.vllm.sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, _encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache, self.k_cache.kv_cache,
q_fp8, q_fp8,
k, k,
...@@ -366,7 +372,7 @@ class SparseAttnIndexer(CustomOp): ...@@ -366,7 +372,7 @@ class SparseAttnIndexer(CustomOp):
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, _encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache, self.k_cache.kv_cache,
q_fp8, q_fp8,
k, k,
......
...@@ -15,6 +15,7 @@ from packaging import version ...@@ -15,6 +15,7 @@ from packaging import version
from packaging.version import Version from packaging.version import Version
from torch.library import Library, infer_schema from torch.library import Library, infer_schema
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -709,37 +710,62 @@ def is_torch_equal(target: str) -> bool: ...@@ -709,37 +710,62 @@ def is_torch_equal(target: str) -> bool:
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev") HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")
# Allow toggling LayerName usage via environment variable.
# Defaults to True on torch >= 2.11, False otherwise.
# Set VLLM_USE_LAYERNAME=0 to disable even on torch >= 2.11.
_USE_LAYERNAME = HAS_OPAQUE_TYPE and envs.VLLM_USE_LAYERNAME
if HAS_OPAQUE_TYPE: if HAS_OPAQUE_TYPE:
from torch._opaque_base import OpaqueBase from torch._opaque_base import OpaqueBase
else: else:
OpaqueBase = object # type: ignore[misc, assignment] OpaqueBase = object # type: ignore[misc, assignment]
class ModuleName(OpaqueBase): # type: ignore[misc] class LayerName(OpaqueBase): # type: ignore[misc]
"""Wraps a module name string for use as a torch opaque type. """Wraps a module name string for use as a torch opaque type.
When torch >= 2.11, this is registered as a hoisted value-type opaque When torch >= 2.11, this is registered as a hoisted value-type opaque
object so that torch.compile lifts it as a graph input instead of baking object so that torch.compile lifts it as a graph input instead of baking
it as a constant. This avoids per-layer recompilation for MOE ops. it as a constant. This avoids per-layer recompilation for custom ops
that accept layer name strings (attention, MOE, KV cache, etc.).
""" """
def __init__(self, value: str): def __init__(self, value: str):
self.value = value self.value = value
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, ModuleName) and self.value == other.value return isinstance(other, LayerName) and self.value == other.value
def __hash__(self): def __hash__(self):
return hash(self.value) return hash(self.value)
def __fx_repr__(self): def __fx_repr__(self):
return (f"ModuleName({self.value!r})", {ModuleName}) return (f"LayerName({self.value!r})", {"LayerName": LayerName})
if HAS_OPAQUE_TYPE: if HAS_OPAQUE_TYPE:
from torch._library.opaque_object import register_opaque_type from torch._library.opaque_object import register_opaque_type
register_opaque_type(ModuleName, typ="value", hoist=True) register_opaque_type(LayerName, typ="value", hoist=True)
# On torch >= 2.11 (with VLLM_USE_LAYERNAME enabled), custom op
# layer_name parameters use LayerName; otherwise they remain plain str.
if TYPE_CHECKING:
from typing import TypeAlias
LayerNameType: TypeAlias = str | LayerName
else:
LayerNameType = LayerName if _USE_LAYERNAME else str
def _resolve_layer_name(layer_name: str | LayerName) -> str:
"""Unwrap a LayerName to str, or return str unchanged."""
return layer_name.value if isinstance(layer_name, LayerName) else layer_name
def _encode_layer_name(layer_name: str) -> str | LayerName:
"""Wrap a str layer name as LayerName when enabled."""
return LayerName(layer_name) if _USE_LAYERNAME else layer_name
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import LayerNameType
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
...@@ -459,7 +460,7 @@ def rocm_fp8_mqa_logits( ...@@ -459,7 +460,7 @@ def rocm_fp8_mqa_logits(
def rocm_aiter_sparse_attn_indexer_fake( def rocm_aiter_sparse_attn_indexer_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -486,7 +487,7 @@ def rocm_aiter_sparse_attn_indexer_fake( ...@@ -486,7 +487,7 @@ def rocm_aiter_sparse_attn_indexer_fake(
def rocm_aiter_sparse_attn_indexer( def rocm_aiter_sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -502,6 +503,9 @@ def rocm_aiter_sparse_attn_indexer( ...@@ -502,6 +503,9 @@ def rocm_aiter_sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
from vllm.utils.torch_utils import _resolve_layer_name
k_cache_prefix = _resolve_layer_name(k_cache_prefix)
# assert isinstance(attn_metadata, dict) # assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
return rocm_aiter_sparse_attn_indexer_fake( return rocm_aiter_sparse_attn_indexer_fake(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment