"vllm/vscode:/vscode.git/clone" did not exist on "aec90b84aa1acc29d7e5aa575b83c2617740a114"
Unverified Commit 2ff4e511 authored by Rohan Potdar's avatar Rohan Potdar Committed by GitHub
Browse files
parent 95642441
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
import torch
import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer
from vllm.compilation.passes.fusion.act_quant_fusion import (
......@@ -31,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -198,23 +200,82 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
return [torch.ops.aten.slice_scatter.default]
MODELS = [
TestSiluMul,
TestFusedAddRMSNorm,
TestRotaryEmbedding,
TestRotaryEmbeddingSliceScatter,
]
class TestFunctionWithMutatedArgsAndReturn(torch.nn.Module):
OP_REGISTERED = False
def __init__(self):
super().__init__()
self.register_test_custom_op()
@classmethod
def register_test_custom_op(cls):
if not cls.OP_REGISTERED:
def function_with_mutated_args_and_return_impl(
x: torch.Tensor,
) -> torch.Tensor:
ret = x + 1
x.add_(2)
return ret
def function_with_mutated_args_and_return_fake(
x: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)
direct_register_custom_op(
op_name="function_with_mutated_args_and_return",
op_func=function_with_mutated_args_and_return_impl,
mutates_args=["x"],
fake_impl=function_with_mutated_args_and_return_fake,
)
cls.OP_REGISTERED = True
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Clone x to avoid mutating the original tensor
ret = torch.ops.vllm.function_with_mutated_args_and_return(x)
return x, ret
def example_inputs(self, num_tokens=32):
hidden_states = torch.randn(num_tokens)
return (hidden_states,)
def ops_in_model(self, do_fusion):
return [torch.ops.vllm.function_with_mutated_args_and_return.default]
def ops_not_in_model(self):
return []
MODELS_AND_DO_FUSION = {
TestSiluMul: [True, False],
TestFusedAddRMSNorm: [True, False],
TestRotaryEmbedding: [False],
TestRotaryEmbeddingSliceScatter: [False],
TestFunctionWithMutatedArgsAndReturn: [False],
}
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
@pytest.mark.parametrize(
"model_class, do_fusion",
[
(model_class, do_fusion)
for model_class, fusions in MODELS_AND_DO_FUSION.items()
for do_fusion in fusions
],
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="Only test on cuda and rocm platform",
)
def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
......@@ -246,8 +307,14 @@ def test_fix_functionalization(
backend_no_func = TestBackend(*passes)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
inputs_func = model.example_inputs()
inputs_no_func = copy.deepcopy(inputs_func)
model_func = model_class()
model_no_func = copy.deepcopy(model_func)
model_func = torch.compile(model_func, backend=backend_func)
model_no_func = torch.compile(model_no_func, backend=backend_no_func)
model_func(*inputs_func)
model_no_func(*inputs_no_func)
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
......@@ -265,3 +332,8 @@ def test_fix_functionalization(
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
# TODO (Rohan138): compare the outputs from model_func and model_no_func
# currently runs into errors while comparing `TestFusedAddRMSNorm`
# Linked issue: https://github.com/vllm-project/vllm/issues/34996
# torch.testing.assert_close(outputs_func, outputs_no_func)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import vllm.config
from tests.compile.backend import TestBackend
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP
from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.utility.scatter_split_replace import (
ScatterSplitReplacementPass,
)
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
FP8_DTYPE = current_platform.fp8_dtype()
class QKRoPEKVCacheTestModel(torch.nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
attn_backend: AttentionBackendEnum,
num_heads: int,
num_kv_heads: int,
head_size: int,
is_neox: bool,
dtype: torch.dtype,
device: torch.device,
prefix: str = "model.layers.0.self_attn.attn",
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.block_size = vllm_config.cache_config.block_size
self.q_size = num_heads * head_size
self.kv_size = num_kv_heads * head_size
self.is_neox = is_neox
self.dtype = dtype
self.device = device
self.layer_name = prefix
self.rotary_emb = RotaryEmbedding(
head_size,
rotary_dim=head_size,
max_position_embeddings=4096,
base=10000,
is_neox_style=is_neox,
dtype=self.dtype,
)
# Whether to check for the RoPE custom op or component index_select
self.enable_rope_custom_op = self.rotary_emb.enabled()
# Register layer metadata for the fusion pass via Attention.
self.attn = Attention(
num_heads=num_heads,
head_size=head_size,
scale=1.0 / head_size**0.5,
num_kv_heads=num_kv_heads,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
prefix=prefix,
attn_backend=attn_backend.get_class(),
)
self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend()
assert not self.attn_backend.forward_includes_kv_cache_update, (
f"Attention backend {self.attn_backend} does not support fuse_rope_kvcache."
)
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
kv_cache_dtype_str = vllm_config.cache_config.cache_dtype
self.kv_cache_dtype = (
FP8_DTYPE if kv_cache_dtype_str.startswith("fp8") else self.dtype
)
# Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()(
kv_cache_spec=AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=head_size,
dtype=self.kv_cache_dtype,
),
layer_names=[self.attn.layer_name],
vllm_config=vllm_config,
device=device,
)
def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec, self.block_size, self.device, arange_block_indices=True
)
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
# Fetch the attention backend and kv cache shape and stride order
attn_backend = self.attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
# Create dummy KV cache
raw_tensor = torch.zeros(
2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache]
# Build attn metadata
attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
return attn_metadata
def forward(
self, qkv: torch.Tensor, positions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Create copy so inplace ops do not modify the original tensors
qkv = qkv.clone()
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
# Instead of a full forward pass, match only the KV cache update op here
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
k, v, self.layer_name
)
return q, k, v, kv_cache_dummy_dep
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rope_custom_op:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]
@pytest.mark.parametrize(
"attn_backend",
[
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.ROCM_ATTN,
],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
@pytest.mark.parametrize("num_heads", [64])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("is_neox", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.skipif(
not is_aiter_found_and_supported(),
reason="Only test on ROCm with AITER installed and supported",
)
def test_rope_kvcache_fusion(
attn_backend: AttentionBackendEnum,
enable_rope_custom_op: bool,
num_heads: int,
num_kv_heads: int,
head_size: int,
block_size: int,
is_neox: bool,
dtype: torch.dtype,
kv_cache_dtype: str,
monkeypatch: pytest.MonkeyPatch,
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
custom_ops: list[str] = []
if enable_rope_custom_op:
custom_ops.append("+rotary_embedding")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
cache_config=CacheConfig(
block_size=block_size,
cache_dtype=kv_cache_dtype,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(
fuse_rope_kvcache=True,
eliminate_noops=True,
),
),
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
model = QKRoPEKVCacheTestModel(
vllm_config=vllm_config,
attn_backend=attn_backend,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
is_neox=is_neox,
dtype=dtype,
device=torch.get_default_device(),
)
fusion_pass = RopeKVCacheFusionPass(vllm_config)
passes = [
NoOpEliminationPass(vllm_config),
SplitCoalescingPass(vllm_config),
ScatterSplitReplacementPass(vllm_config),
fusion_pass,
PostCleanupPass(vllm_config),
]
backend = TestBackend(*passes)
T = 5
qkv = torch.randn(
T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype
)
pos = torch.arange(T, dtype=torch.long)
qkv_unfused = qkv.clone()
pos_unfused = pos.clone()
with set_forward_context(None, vllm_config):
forward_context = get_forward_context()
attn_metadata = model.build_attn_metadata(T)
forward_context.slot_mapping = {
model.layer_name: attn_metadata.slot_mapping
}
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine]
del dummy
torch._dynamo.mark_dynamic(qkv, 0)
torch._dynamo.mark_dynamic(pos, 0)
with set_forward_context(None, vllm_config):
model_fused = torch.compile(model, backend=backend)
forward_context = get_forward_context()
attn_metadata = model_fused.build_attn_metadata(T)
forward_context.slot_mapping = {
model.layer_name: attn_metadata.slot_mapping
}
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine]
del dummy
assert fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
# Cannot compare fp8_* directly here, cast to model dtype instead
torch.testing.assert_close(
kv_cache_unfused.view(dtype),
kv_cache_fused.view(dtype),
atol=ATOL,
rtol=RTOL,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn as nn
import vllm
from tests.compile.backend import TestBackend
from vllm.compilation.passes.utility.scatter_split_replace import (
ScatterSplitReplacementPass,
)
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
class ScatterSplitReplacementModel(nn.Module):
"""Model with a rope+getitem+slice_scatter+split_with_sizes sequence."""
def __init__(
self,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
):
super().__init__()
self.q_size = num_heads * head_size
self.kv_size = num_kv_heads * head_size
self.rotary_emb = RotaryEmbedding(
head_size,
rotary_dim=head_size,
max_position_embeddings=4096,
base=10000,
is_neox_style=True,
dtype=dtype,
)
def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
# Create copy so inplace ops do not modify the original tensors
qkv = qkv.clone()
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
q = q + 1
k = k + 2
v = v + 3
return q, k, v
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
return [
torch.ops.aten.slice_scatter.default,
torch.ops.aten.split_with_sizes.default,
torch.ops.aten.getitem.default,
]
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
return [torch.ops.aten.getitem.default]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scatter_split_replace(dtype):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
num_heads = 8
num_kv_heads = 4
head_size = 64
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rotary_embedding"],
),
)
with vllm.config.set_current_vllm_config(vllm_config):
# ScatterSplitReplacementPass requires SplitCoalescingPass to be run before it
coalesce_pass = SplitCoalescingPass(vllm_config)
replace_pass = ScatterSplitReplacementPass(vllm_config)
passes = [coalesce_pass, replace_pass]
backend = TestBackend(*passes)
model = ScatterSplitReplacementModel(num_heads, num_kv_heads, head_size, dtype)
T = 5
qkv = torch.randn(
T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype
)
pos = torch.arange(T, dtype=torch.long)
qkv_eager = qkv.clone()
pos_eager = pos.clone()
result_eager = model(qkv_eager, pos_eager)
torch._dynamo.mark_dynamic(qkv, 0)
torch._dynamo.mark_dynamic(pos, 0)
model_compiled = torch.compile(model, backend=backend)
result_compiled = model_compiled(qkv, pos)
for eager, compiled in zip(result_eager, result_compiled):
torch.testing.assert_close(eager, compiled)
assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1
......@@ -1518,6 +1518,45 @@ class rocm_aiter_ops:
query = query.view(query_shape)
key = key.view(key_shape)
@staticmethod
def triton_rope_and_cache(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
flash_layout: bool,
apply_scale: bool,
):
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
cos, sin = cos_sin_cache.chunk(2, dim=-1)
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
layer_slot_mapping,
positions,
cos,
sin,
k_scale,
v_scale,
is_neox,
flash_layout=flash_layout,
apply_scale=apply_scale,
q_out=query,
k_out=key,
output_zeros=False,
)
@staticmethod
def batched_gemm_a16wfp4(
X: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops import auto_functionalized
from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.attention import (
Attention,
get_attention_context,
)
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherRotaryEmbedding,
)
from .rms_quant_fusion import (
empty_bf16,
empty_i64,
)
logger = init_logger(__name__)
def fused_rope_and_unified_kv_cache_update_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
) -> torch.Tensor:
"""
This impl fetches the KV cache and slot mapping from the forward context,
then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
attn_layer.impl.do_rope_and_kv_cache_update(
attn_layer,
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def fused_rope_and_unified_kv_cache_update_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
layer_name: str = "",
) -> torch.Tensor:
return torch.empty(0, device=query.device, dtype=query.dtype)
direct_register_custom_op(
op_name="fused_rope_and_unified_kv_cache_update",
op_func=fused_rope_and_unified_kv_cache_update_impl,
mutates_args=["query", "key"],
fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
class RopeReshapeKVCachePattern:
"""
This pattern matches the following unfused inplace ops:
q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)
and replaces it with the fused inplace op:
kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
q, k, v, positions, cos_sin_cache, is_neox, layer_name
)
"""
FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
def __init__(
self,
layer: Attention,
is_neox: bool,
) -> None:
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.num_kv_heads = layer.num_kv_heads
self.head_size = layer.head_size
self.head_size_v = layer.head_size_v
self.is_neox = is_neox
self.q_size = self.num_heads * self.head_size
self.k_size = self.num_kv_heads * self.head_size
self.v_size = self.num_kv_heads * self.head_size_v
self.rope_matcher = MatcherRotaryEmbedding(
is_neox=self.is_neox,
head_size=self.head_size,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
def get_inputs(self) -> list[torch.Tensor]:
# Sample inputs to help pattern tracing
T = 5
L = 4096
qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
positions = empty_i64(T)
cos_sin_cache = empty_bf16(L, self.head_size)
return [
qkv,
positions,
cos_sin_cache,
]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
return dummy, q, k, v
def replacement(
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
results = auto_functionalized(
self.FUSED_OP,
query=q,
key=k,
value=v,
positions=positions,
cos_sin_cache=cos_sin_cache,
is_neox=self.is_neox,
layer_name=self.layer_name,
)
return results[0], results[1], results[2], v
# NOTE: use view_to_reshape to unify view/reshape to simplify
# pattern and increase matching opportunities
def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
gm = pm.fwd_only(*args, **kwargs)
view_to_reshape(gm)
return gm
pm.register_replacement(
pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
)
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
"""
This pass fuses the rotary embedding and KV cache update operations
into a single fused kernel if available.
It uses the pattern matcher and matches each layer manually, as strings
cannot be wildcarded. This also lets us check support on attention layers
upon registration instead of during pattern matching.
This fusion eliminates the need for separate kernel launches and
intermediate memory operations between the RoPE and cache update steps.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rope_kv_cache_fusion_pass"
)
cc = config.compilation_config
self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num
attn_layers = get_layers_from_vllm_config(config, Attention)
for _, layer in attn_layers.items():
if layer.impl.fused_rope_kvcache_supported():
for is_neox in [True, False]:
RopeReshapeKVCachePattern(
layer=layer,
is_neox=is_neox,
).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass works best for the small-batch decode setting.
# For large-batch e.g. prefill, it is better to use two separate kernels
# since they are compute bound and the fused kernels require further tuning.
return compile_range.end <= self.max_token_num
def uuid(self) -> str:
return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
......@@ -28,7 +28,9 @@ if current_platform.is_cuda_alike():
from .fusion.attn_quant_fusion import AttnFusionPass
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from .fusion.sequence_parallelism import SequenceParallelismPass
from .utility.scatter_split_replace import ScatterSplitReplacementPass
from .utility.split_coalescing import SplitCoalescingPass
if current_platform.is_cuda():
......@@ -136,6 +138,11 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
if self.pass_config.fuse_rope_kvcache:
self.passes += [SplitCoalescingPass(config)]
self.passes += [ScatterSplitReplacementPass(config)]
self.passes += [RopeKVCacheFusionPass(config)]
if self.pass_config.fuse_attn_quant:
self.passes += [AttnFusionPass(config)]
......
......@@ -162,6 +162,24 @@ class FixFunctionalizationPass(VllmInductorPass):
"position_ids",
)
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
elif (
hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
and at_target
== torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
):
mutated_args = {
1: "query",
2: "key",
}
self.defunctionalize(graph, node, mutated_args=mutated_args)
# only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
elif (
hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
and at_target
== torch.ops.vllm.function_with_mutated_args_and_return.default
):
mutated_args = {1: "x"}
self.defunctionalize(graph, node, mutated_args=mutated_args)
else:
continue # skip the count
......@@ -208,13 +226,20 @@ class FixFunctionalizationPass(VllmInductorPass):
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
) -> None:
"""
Replace all getitem users of the auto-functionalized node with the
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
# Some functionalized nodes may return both a result at getitem[0]
# as well as mutated args at getitem[1:...]
if idx == 0:
assert idx not in mutated_args, (
f"result at getitem[0] should not be in mutated_args for {node}"
)
continue
arg = mutated_args[idx]
arg = node.kwargs[arg] if isinstance(arg, str) else arg
user.replace_all_uses_with(arg)
......@@ -257,10 +282,20 @@ class FixFunctionalizationPass(VllmInductorPass):
with graph.inserting_before(node):
function = node.args[0]
if args is None:
graph.call_function(function, kwargs=node.kwargs)
fn_node = graph.call_function(function, kwargs=node.kwargs)
else:
# Args passed as strings refer to items in node.kwargs
args = tuple(
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
)
graph.call_function(function, args=args)
fn_node = graph.call_function(function, args=args)
# If the function returns a value as well as mutating args inplace,
# the functionalized node will have a getitem[0] user that holds this value
# Replace getitem[0] user of the auto-functionalized node
# with the new defunctionalized node directly if it exists
users = self.getitem_users(node)
if 0 in users:
user = users[0]
user.replace_all_uses_with(fn_node)
self._remove(user)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
assignment if there are no users for the inplace tensor written to by
the slice_scatter call.
The inplace rotary_embedding custom op takes in mutable query and key inputs
that are split+getitem outputs of a single qkv tensor.
When functionalized, we fetch the rotated query and key from the functionalized op
using `getitem` calls. However, we also write to the qkv tensor inplace using a
`slice_scatter`, then split the inplace tensor to get the output tensors again.
Instead, if the inplace tensor has no subsequent users, we can just replace the
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
This is already done in fix_functionalization::FixFunctionalizationPass, but
writing a custom pass for it before defunctionalization allows matching against the
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
"""
import operator
import torch
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from ..fx_utils import is_func
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class ScatterSplitReplacementPass(VllmInductorPass):
"""Replace getitem+slice_scatter+split nodes with a single getitem when
the inplace subtensor written to by the slice_scatter has no other users.
Here's an example graph with q_size = 512, kv_size = 64:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
q = operator.getitem(split_with_sizes_2, 0)
k = operator.getitem(split_with_sizes_2, 1)
v = operator.getitem(split_with_sizes_2, 2)
After this pass, this sequence of nodes is replaced with:
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
q = operator.getitem(at, 1)
k = operator.getitem(at, 2)
v = operator.getitem(split_with_sizes_1, 2)
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = {}
for user in node.users:
if is_func(user, operator.getitem):
getitem_nodes[user.args[1]] = user
if (
is_func(query, operator.getitem)
and is_func(key, operator.getitem)
and query.args[0] == key.args[0]
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
and all(
is_func(user, torch.ops.aten.slice_scatter.default)
for getitem_node in getitem_nodes.values()
for user in getitem_node.users
)
):
# Pattern where query and key are slices of a qkv tensor.
# While functionalized, results at [1] and [2] are scattered
# back into qkv, then split again to get query and key.
# If the inplace tensor has no other users, we can replace
# the slice_scatter+split nodes with the original results.
for user in getitem_nodes[1].users:
slice_scatter_1_node = user
if not is_func(
slice_scatter_1_node, torch.ops.aten.slice_scatter.default
):
continue
for user in getitem_nodes[2].users:
slice_scatter_2_node = user
if not is_func(
slice_scatter_2_node, torch.ops.aten.slice_scatter.default
):
continue
for user in slice_scatter_2_node.users:
split_node = user
if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
continue
split_getitem_users = {}
for user in split_node.users:
if is_func(user, operator.getitem):
split_getitem_users[user.args[1]] = user
# Replace query node
split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
graph.erase_node(split_getitem_users[0])
# Replace key node
split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
graph.erase_node(split_getitem_users[1])
# Redirect value node to original qkv tensor
split_getitem_users[2].replace_input_with(split_node, query.args[0])
# Erase unused nodes
graph.erase_node(split_node)
graph.erase_node(slice_scatter_2_node)
graph.erase_node(slice_scatter_1_node)
count += 1
logger.debug("Eliminated %d slice_scatter+split nodes", count)
......@@ -127,6 +127,13 @@ class PassConfig:
# ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None)
"""Fuse the custom RMSNorm + padding ops."""
fuse_rope_kvcache: bool = Field(default=None)
"""Fuse the QK rope + KV cache ops."""
rope_kvcache_fusion_max_token_num: int = 256
"""The threshold for ROCm AITER RoPE+KVCache fusion e.g. for small batch decode.
Larger batch sizes e.g. during prefill will use the unfused kernels.
"""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
......@@ -198,6 +205,7 @@ class PassConfig:
"fuse_gemm_comms",
"fuse_allreduce_rms",
"fuse_act_padding",
"fuse_rope_kvcache",
mode="wrap",
)
@classmethod
......@@ -243,6 +251,12 @@ class PassConfig:
"The fusion will be disabled."
)
self.fuse_act_padding = False
if self.fuse_rope_kvcache and not current_platform.is_rocm():
logger.warning_once(
"KV cache fusion currently only enabled on ROCm. "
"The fusion will be disabled."
)
self.fuse_rope_kvcache = False
class DynamicShapesType(str, enum.Enum):
......@@ -824,6 +838,19 @@ class CompilationConfig:
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if self.pass_config.fuse_rope_kvcache:
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
logger.warning(
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
)
self.pass_config.fuse_rope_kvcache = False
else:
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
is_torch_equal_or_newer("2.9.0.dev")
......
......@@ -1401,6 +1401,20 @@ class VllmConfig:
"allreduce-rms fusion will be enabled for all num_tokens."
)
if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = (
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
)
if max_token_num is not None:
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below rope+kvcache fusion threshold, "
"rope+kvcache fusion enabled for num_tokens <= %d.",
compile_range_end,
)
if compilation_config.compile_ranges_split_points is not None:
for x in compilation_config.compile_ranges_split_points:
assert isinstance(x, int)
......
......@@ -570,11 +570,11 @@ direct_register_custom_op(
def get_attention_context(
layer_name: str,
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]:
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
"""Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer
instance, and KV cache tensor for a specific layer.
instance, KV cache tensor, and slot mapping for a specific layer.
Args:
layer_name: The name/identifier of the attention layer.
......@@ -585,6 +585,7 @@ def get_attention_context(
no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current virtual engine
- slot_mapping: The slot mapping for this specific layer
Note: attn_metadata may be None, but attn_layer and kv_cache are always
extracted from the forward context.
......@@ -593,9 +594,14 @@ def get_attention_context(
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_layer = forward_context.no_compile_layers[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
return attn_metadata, attn_layer, kv_cache
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
@maybe_transfer_kv_layer
......@@ -605,7 +611,7 @@ def unified_attention(
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
return output
......@@ -636,15 +642,7 @@ def unified_kv_cache_update(
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
......@@ -691,7 +689,7 @@ def unified_attention_with_output(
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, self, kv_cache = get_attention_context(layer_name)
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
self.impl.forward(
self,
......
......@@ -40,8 +40,8 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
layer_name: str = args[layer_name_index]
# Extract attention context (layer-specific metadata, layer, and kv_cache)
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)
......
......@@ -828,7 +828,7 @@ def unified_mla_attention(
k_pe: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
......@@ -862,7 +862,7 @@ def unified_mla_attention_with_output(
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
layer.forward_impl(
q,
kv_c_normed,
......
......@@ -723,6 +723,33 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
"""
return False
def fused_rope_kvcache_supported(self):
"""
Does this attention implementation support RoPE+KVCache fusion.
This is used by the RopeKVCacheFusionPass to only fuse the RoPE ops
with the KV cache update for implementations that support it.
"""
return False
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
"""
If `fused_rope_kvcache_supported` returns True, this method will be called
by torch.ops.vllm.fused_rope_and_unified_kv_cache_update
to perform the inplace RoPE and KV cache update.
"""
raise NotImplementedError
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MLA attention implementation with forward_mqa and forward_mha methods."""
......
......@@ -11,7 +11,6 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention.attention import get_attention_context
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
......@@ -1290,11 +1289,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
attn_metadata, _, _ = get_attention_context(layer.layer_name)
if attn_metadata is None:
# Profiling run.
return
key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
......@@ -1303,11 +1297,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
......@@ -1319,8 +1308,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
# We may calculate per token quant scale in
# reshape_and_cache_shuffle_triton which might differ from
# vllm's style when shuffle layout is used.
k_scale = attn_metadata.k_scale
v_scale = attn_metadata.v_scale
k_scale = layer._k_scale
v_scale = layer._v_scale
assert k_scale is not None and v_scale is not None, (
"k_scale and v_scale are required for shuffled update"
)
......
......@@ -5,6 +5,7 @@
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
......@@ -207,3 +208,42 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)
......@@ -7,6 +7,7 @@ from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......@@ -415,3 +416,46 @@ class RocmAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache,
layer.num_kv_heads, # type: ignore[attr-defined]
layer.head_size, # type: ignore[attr-defined]
)
flash_layout = False
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)
......@@ -7,6 +7,7 @@ from typing import ClassVar
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
......@@ -596,3 +597,42 @@ class TritonAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
def fused_rope_kvcache_supported(self):
return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment