# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
import torch.nn as nn

import vllm
from tests.compile.backend import TestBackend
from vllm.compilation.passes.utility.scatter_split_replace import (
    ScatterSplitReplacementPass,
)
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform

# Device type reported by the active vLLM platform; the test below uses it as
# the default device for every tensor it allocates.
DEVICE_TYPE = current_platform.device_type


class ScatterSplitReplacementModel(nn.Module):
    """Model with a rope+getitem+slice_scatter+split_with_sizes sequence.

    ``forward`` splits a fused ``qkv`` tensor into q/k/v segments, applies the
    rotary embedding to q and k, and adds a distinct constant to each of the
    three results so every segment is consumed after the rotary call.
    """

    def __init__(
        self,
        num_heads: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ):
        super().__init__()
        # Widths of the q segment and of each k/v segment along the last
        # dimension of the fused qkv input.
        self.q_size, self.kv_size = (
            num_heads * head_size,
            num_kv_heads * head_size,
        )

        # Rotary embedding over the whole head (rotary_dim == head_size).
        self.rotary_emb = RotaryEmbedding(
            head_size,
            dtype=dtype,
            base=10000,
            rotary_dim=head_size,
            is_neox_style=True,
            max_position_embeddings=4096,
        )

    def forward(
        self, qkv: torch.Tensor, positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Split a private copy: the rotary embedding may update q/k in place,
        # and those are views into the fused tensor.
        segment_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.clone().split(segment_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        return query + 1, key + 2, value + 3

    def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
        """Return the ops expected in the graph before the replacement pass."""
        # NOTE(review): confirm `torch.ops.aten.getitem` resolves; FX graphs
        # usually represent tuple indexing with `operator.getitem`.
        aten = torch.ops.aten
        return [
            aten.slice_scatter.default,
            aten.split_with_sizes.default,
            aten.getitem.default,
        ]

    def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
        """Return the ops expected in the graph after the replacement pass."""
        aten = torch.ops.aten
        return [aten.getitem.default]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scatter_split_replace(dtype):
    """Check that ScatterSplitReplacementPass removes all slice_scatter ops.

    The model is compiled with SplitCoalescingPass followed by
    ScatterSplitReplacementPass. The compiled outputs must match eager
    execution, and the resulting graph must contain no ``slice_scatter`` and
    exactly one ``split_with_sizes``.
    """
    num_heads = 8
    num_kv_heads = 4
    head_size = 64
    num_tokens = 5

    # The default dtype is process-wide state: restore it afterwards so the
    # half-precision default does not leak into tests that run later.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        # Scope the default device to this test instead of changing it for
        # the rest of the process with torch.set_default_device.
        with torch.device(DEVICE_TYPE):
            torch.manual_seed(0)

            vllm_config = VllmConfig(
                compilation_config=CompilationConfig(
                    mode=CompilationMode.VLLM_COMPILE,
                    custom_ops=["+rotary_embedding"],
                ),
            )
            with vllm.config.set_current_vllm_config(vllm_config):
                # ScatterSplitReplacementPass requires SplitCoalescingPass to
                # be run before it.
                coalesce_pass = SplitCoalescingPass(vllm_config)
                replace_pass = ScatterSplitReplacementPass(vllm_config)
                backend = TestBackend(coalesce_pass, replace_pass)

                model = ScatterSplitReplacementModel(
                    num_heads, num_kv_heads, head_size, dtype
                )

                qkv = torch.randn(
                    num_tokens,
                    num_heads * head_size + 2 * num_kv_heads * head_size,
                    dtype=dtype,
                )
                pos = torch.arange(num_tokens, dtype=torch.long)

                # Eager reference, computed on separate copies of the inputs.
                result_eager = model(qkv.clone(), pos.clone())

                # Compile with a dynamic token dimension.
                torch._dynamo.mark_dynamic(qkv, 0)
                torch._dynamo.mark_dynamic(pos, 0)
                model_compiled = torch.compile(model, backend=backend)
                result_compiled = model_compiled(qkv, pos)

                # zip() would silently drop extra outputs; check counts first.
                assert len(result_compiled) == len(result_eager)
                for eager, compiled in zip(result_eager, result_compiled):
                    torch.testing.assert_close(eager, compiled)

                aten = torch.ops.aten
                num_scatter = backend.op_count(aten.slice_scatter.default)
                num_split = backend.op_count(aten.split_with_sizes.default)
                assert num_scatter == 0, (
                    f"expected no slice_scatter after the pass, found {num_scatter}"
                )
                assert num_split == 1, (
                    f"expected exactly one split_with_sizes, found {num_split}"
                )
    finally:
        torch.set_default_dtype(prev_dtype)