# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops import auto_functionalized
from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.attention import (
    Attention,
    get_attention_context,
)
from vllm.utils.torch_utils import (
    _USE_LAYERNAME,
    LayerNameType,
    _encode_layer_name,
    _resolve_layer_name,
    direct_register_custom_op,
)

from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
    MatcherRotaryEmbedding,
)
from .rms_quant_fusion import (
    empty_bf16,
    empty_i64,
)

logger = init_logger(__name__)


def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: LayerNameType,
) -> torch.Tensor:
    """
    This impl fetches the KV cache and slot mapping from the forward context,
    then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update`
    method.

    It also returns a dummy tensor, similar to
    `Attention.unified_kv_cache_update`, that is passed to unified_attention
    to signal a side effect and the data dependency between them to ensure
    torch.compile preserves ordering.
    """
    layer_name = _resolve_layer_name(layer_name)
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    # NOTE(review): when no slot mapping is available, both the rope
    # application and the cache write are skipped (query/key stay
    # unrotated). Presumably this only happens in runs whose attention
    # output is not consumed; confirm against get_attention_context callers.
    if layer_slot_mapping is not None:
        attn_layer.impl.do_rope_and_kv_cache_update(
            attn_layer,
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            kv_cache,
            layer_slot_mapping,
        )

    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)


def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: LayerNameType,
) -> torch.Tensor:
    """Fake (meta) impl: returns an empty dummy tensor on query's device.

    NOTE(review): the real impl builds its dummy with kv_cache.dtype while
    this uses query.dtype; the two can differ (e.g. quantized KV cache).
    The dummy is empty and only carries an ordering dependency, so this is
    presumably harmless — confirm it mirrors unified_kv_cache_update's fake.
    """
    return torch.empty(0, device=query.device, dtype=query.dtype)


direct_register_custom_op(
    op_name="fused_rope_and_unified_kv_cache_update",
    op_func=fused_rope_and_unified_kv_cache_update_impl,
    mutates_args=["query", "key"],
    fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)


class RopeReshapeKVCachePattern:
    """
    This pattern matches the following unfused inplace ops:
      q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
      kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and replaces it with the fused inplace op:
      kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
          q, k, v, positions, cos_sin_cache, is_neox, layer_name
      )
    """

    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v
        self.is_neox = is_neox

        # Widths of the q / k / v slices of the packed qkv projection.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list:
        """Return sample inputs used only to trace the pattern/replacement."""
        T = 5
        L = 4096
        qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
        positions = empty_i64(T)
        cos_sin_cache = empty_bf16(L, self.head_size)
        inputs: list = [qkv, positions, cos_sin_cache]
        if _USE_LAYERNAME:
            # layer_name is an explicit (wildcard) pattern input in this mode.
            inputs.append(_encode_layer_name(self.layer_name))
        return inputs

    def _split_qkv(self, qkv):
        """Split the packed qkv projection into flat q, k, v slices."""
        return qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)

    def _view_heads(self, q, k, v):
        """View flat q, k, v as (-1, heads, head_dim) tensors."""
        q = q.view(-1, self.num_heads, self.head_size)
        k = k.view(-1, self.num_kv_heads, self.head_size)
        v = v.view(-1, self.num_kv_heads, self.head_size_v)
        return q, k, v

    def _unfused_graph(self, qkv, positions, cos_sin_cache, layer_name):
        """Trace target: separate rope followed by the KV cache update."""
        q, k, v = self._split_qkv(qkv)
        q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
        q, k, v = self._view_heads(q, k, v)
        return torch.ops.vllm.unified_kv_cache_update(k, v, layer_name), q, k, v

    def _fused_graph(self, qkv, positions, cos_sin_cache, layer_name):
        """Replacement: one fused op that ropes q/k and writes the cache."""
        q, k, v = self._view_heads(*self._split_qkv(qkv))
        results = auto_functionalized(
            self.FUSED_OP,
            query=q,
            key=k,
            value=v,
            positions=positions,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
            layer_name=layer_name,
        )
        # results = (dummy output, mutated query, mutated key); v is untouched.
        return results[0], results[1], results[2], v

    def _mk_pattern_with_layer_name_input(self, _ln):
        """Pattern/replacement with layer_name as an explicit input."""

        # Explicit parameter names are kept: the pattern matcher maps the
        # sample inputs onto these arguments positionally by name.
        def pattern(qkv, positions, cos_sin_cache, layer_name):
            return self._unfused_graph(qkv, positions, cos_sin_cache, layer_name)

        def replacement(qkv, positions, cos_sin_cache, layer_name):
            return self._fused_graph(qkv, positions, cos_sin_cache, layer_name)

        return pattern, replacement

    def _mk_pattern_with_layer_name_closure(self, _ln):
        """Pattern/replacement with layer_name as a closure constant."""

        def pattern(qkv, positions, cos_sin_cache):
            return self._unfused_graph(qkv, positions, cos_sin_cache, _ln)

        def replacement(qkv, positions, cos_sin_cache):
            return self._fused_graph(qkv, positions, cos_sin_cache, _ln)

        return pattern, replacement

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this layer's pattern/replacement pair on ``pm_pass``."""
        _ln = _encode_layer_name(self.layer_name)
        if _USE_LAYERNAME:
            pattern, replacement = self._mk_pattern_with_layer_name_input(_ln)
        else:
            pattern, replacement = self._mk_pattern_with_layer_name_closure(_ln)

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            gm = pm.fwd_only(*args, **kwargs)
            view_to_reshape(gm)
            return gm

        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            fwd_and_view_to_reshape,
            pm_pass,
        )


class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses the rotary embedding and KV cache update operations
    into a single fused kernel if available.

    Without _USE_LAYERNAME, the layer name is baked into each pattern as a
    constant, so one pattern is registered per supported attention layer and
    support is checked per layer at registration time. With _USE_LAYERNAME,
    the layer name is a pattern input (a wildcard), so a single pattern
    matches every layer with the same head counts/sizes; one pattern is then
    registered per distinct shape, and only if every layer of that shape
    supports the fusion.

    This fusion eliminates the need for separate kernel launches and
    intermediate memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        cc = config.compilation_config
        self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in self._layers_to_register(list(attn_layers.values())):
            for is_neox in (True, False):
                RopeReshapeKVCachePattern(
                    layer=layer,
                    is_neox=is_neox,
                ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @staticmethod
    def _layers_to_register(layers: list[Attention]) -> list[Attention]:
        """Pick the attention layers whose patterns should be registered.

        Without _USE_LAYERNAME every supported layer gets its own pattern.
        With _USE_LAYERNAME a pattern matches any layer of the same shape, so
        one representative per shape is returned (registering identical
        patterns twice is rejected by the pattern matcher), and a shape is
        skipped when any of its layers does not support the fusion, since the
        wildcard pattern would otherwise rewrite those layers too.
        """
        if not _USE_LAYERNAME:
            return [
                layer
                for layer in layers
                if layer.impl.fused_rope_kvcache_supported()
            ]

        groups: dict[tuple[int, int, int, int], list[Attention]] = {}
        for layer in layers:
            shape_key = (
                layer.num_heads,
                layer.num_kv_heads,
                layer.head_size,
                layer.head_size_v,
            )
            groups.setdefault(shape_key, []).append(layer)

        selected: list[Attention] = []
        for shape_key, group in groups.items():
            if all(layer.impl.fused_rope_kvcache_supported() for layer in group):
                selected.append(group[0])
            else:
                logger.debug(
                    "Skipping rope+kv-cache fusion for layer shape %s: "
                    "not all layers with this shape support it",
                    shape_key,
                )
        return selected

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        return compile_range.end <= self.max_token_num

    def uuid(self) -> str:
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)