"docs/pages/backends/vllm/README.md" did not exist on "3188c70a3b3291c614dd1f19076d4759a19a73a5"
test_fusion_all_reduce.py 8.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
10
11
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
12
from vllm.compilation.post_cleanup import PostCleanupPass
13
14
15
16
17
18
19
20
from vllm.config import (
    CompilationConfig,
    CompilationLevel,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
)
21
from vllm.distributed import tensor_model_parallel_all_reduce
22
23
24
25
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
26
from vllm.model_executor.layers.layernorm import RMSNorm
27
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
28
29
30
    GroupShape,
    QuantFP8,
)
31
32
33
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

34
from ..utils import has_module_attribute, multi_gpu_test
35
36
37
38
from .backend import TestBackend


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Tensor-parallel all-reduce followed by a plain (no-residual) RMSNorm.

    After the fusion pass runs, the ``all_reduce`` op listed by
    ``ops_in_model_before`` is expected to be replaced by the flashinfer
    TRT-LLM fused allreduce+norm op listed by ``ops_in_model_after``.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        # ``token_num`` is unused by this variant. It is accepted so that every
        # test model can be built as ``cls(hidden_size, token_num)``.
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
        # ``residual`` is intentionally ignored: this variant exercises the
        # no-residual RMSNorm path but keeps the shared
        # (hidden_states, residual) call signature used by the test driver.
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm = self.norm(all_reduce)
        return norm

    def ops_in_model_before(self):
        """Ops expected in the traced graph before the fusion pass."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Ops expected in the traced graph after the fusion pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
    """Tensor-parallel all-reduce followed by a fused residual-add RMSNorm.

    After the fusion pass runs, the ``all_reduce`` op listed by
    ``ops_in_model_before`` is expected to be replaced by the flashinfer
    TRT-LLM fused allreduce+norm op listed by ``ops_in_model_after``.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        # ``token_num`` is unused by this variant. It is accepted so that every
        # test model can be built as ``cls(hidden_size, token_num)``.
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        # With a residual, RMSNorm returns (normed, updated_residual). Only the
        # normalized output is returned here.
        norm, _ = self.norm(all_reduce, residual)
        return norm

    def ops_in_model_before(self):
        """Ops expected in the traced graph before the fusion pass."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Ops expected in the traced graph after the fusion pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
    """All-reduce, fused residual-add RMSNorm, then static per-tensor FP8 quant.

    ``forward`` writes the quantized result in place into a preallocated
    buffer through the ``_C.static_scaled_fp8_quant`` custom op. The fusion pass
    is expected to fold all-reduce, norm and quant into the flashinfer TRT-LLM
    fused op.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)
        # NOTE(review): ``quant_fp8`` is constructed but ``forward`` calls the
        # ``static_scaled_fp8_quant`` op directly instead of this module.
        # Confirm whether it is still needed.
        self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
        # Static (precomputed) per-tensor scale: a single float32 element.
        self.scale = torch.rand(1, dtype=torch.float32)
        # Output buffer that the quant op writes into in place.
        # NOTE(review): allocated as float32 rather than an FP8 dtype. Confirm
        # this is what the kernel and the fusion pattern expect.
        self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)

    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm_output, residual_output = self.norm(all_reduce, residual)
        # The op mutates ``self.output``; it is passed a contiguous input.
        torch.ops._C.static_scaled_fp8_quant(
            self.output, norm_output.contiguous(), self.scale
        )
        return self.output, residual_output

    def ops_in_model_after(self):
        """Ops expected in the traced graph after the fusion pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

    def ops_in_model_before(self):
        """Ops expected in the traced graph before the fusion pass."""
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.static_scaled_fp8_quant.default,
        ]


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
    """All-reduce, fused residual-add RMSNorm, then FP4 quantization.

    ``forward`` writes the quantized values and block scales in place into
    preallocated buffers through the ``_C.scaled_fp4_quant`` custom op. The
    fusion pass is expected to fold all-reduce, norm and quant into the
    flashinfer TRT-LLM fused op.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)
        # Global scale: a single float32 element.
        self.scale = torch.rand(1, dtype=torch.float32)
        # Output buffer that the quant op writes into in place.
        # NOTE(review): allocated as float32 with the unpacked shape. Confirm
        # against what ``scaled_fp4_quant`` expects.
        self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
        self.output_scale = torch.empty(
            self._scale_shape(token_num, hidden_size), dtype=torch.int32
        )

    @staticmethod
    def _scale_shape(token_num, hidden_size):
        """Return the ``(rows, int32_columns)`` shape of the scale buffer.

        One scale is kept per 16 hidden elements. Rows are padded up to a
        multiple of 128 and scale columns up to a multiple of 4. The column
        count is then divided by 4 (presumably four 1-byte scales packed per
        int32; confirm against the ``scaled_fp4_quant`` kernel).
        """

        def round_up(x, y):
            return (x + y - 1) // y * y

        rounded_m = round_up(token_num, 128)
        rounded_n = round_up(hidden_size // 16, 4)
        return rounded_m, rounded_n // 4

    def forward(self, hidden_states, residual):
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm_output, residual_output = self.norm(all_reduce, residual)
        # Flatten to 2-D (tokens, hidden) before quantizing.
        norm_output = norm_output.reshape(-1, norm_output.shape[-1])
        # The op mutates ``self.output`` and ``self.output_scale``.
        torch.ops._C.scaled_fp4_quant(
            self.output, norm_output, self.output_scale, self.scale
        )
        return self.output, residual_output, self.output_scale

    def ops_in_model_after(self):
        """Ops expected in the traced graph after the fusion pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

    def ops_in_model_before(self):
        """Ops expected in the traced graph before the fusion pass."""
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.scaled_fp4_quant.default,
        ]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model",
    [
        TestAllReduceRMSNormModel,
        TestAllReduceFusedAddRMSNormModel,
        TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
        # TODO: Enable with torch==2.8.0
        # TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
    not find_spec("flashinfer")
    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not found or flashinfer "
    "is not compiled with trtllm_allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
    test_model: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
):
    """Check that AllReduceFusionPass fuses all-reduce + RMSNorm (+ quant).

    Spawns one process per rank. Each runs
    ``all_reduce_fusion_pass_on_test_model``, which receives its process index
    from ``torch.multiprocessing.spawn`` as ``local_rank``.
    """
    num_processes = 2

    # The FP4 variant needs compute capability 10.0 (Blackwell).
    if (
        test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
        and not current_platform.has_device_capability(100)
    ):
        pytest.skip(
            "Skip as nvfp4 is only supported on "
            "devices with compute capability 10.0 (Blackwell)"
        )

    torch.multiprocessing.spawn(
        all_reduce_fusion_pass_on_test_model,
        args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
        nprocs=num_processes,
    )


def all_reduce_fusion_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
):
    """Per-rank worker: compile ``test_model_cls`` and verify the fusion.

    Sets up this rank's CUDA device and the distributed / tensor-parallel
    groups. It then builds a VllmConfig with the flashinfer all-reduce fusion
    and no-op elimination passes enabled, compiles the model with a
    TestBackend running the passes, and checks that exactly one all-reduce
    pattern was matched and that the expected ops appear before and after.

    Args:
        local_rank: Process index supplied by ``torch.multiprocessing.spawn``.
        world_size: Number of ranks; also used as the tensor-parallel size.
        test_model_cls: Test model class, constructed as
            ``test_model_cls(hidden_size, token_num)``.
        batch_size: Batch size; ``token_num = batch_size * seq_len``.
        seq_len: Sequence length.
        hidden_size: Hidden dimension of the model inputs.
        dtype: Default torch dtype for this process and the model config.
    """
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            # NOTE(review): fixed port; concurrent runs on one host may collide.
            "MASTER_PORT": "12345",
        }
    )

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"]
        )
    )
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_fi_allreduce_fusion=True, enable_noop=True
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # This is a fake model name used only to construct the model config in
    # vllm_config; the model itself is not used.
    # NOTE(review): ModelConfig presumably still resolves this name (e.g. its
    # HF config), so the test may need network or cache access. Confirm.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
    noop_pass = NoOpEliminationPass(vllm_config)
    func_pass = FixFunctionalizationPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)

    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)

    token_num = batch_size * seq_len
    model = test_model_cls(hidden_size, token_num)

    hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
    residual = torch.randn((token_num, hidden_size), requires_grad=False)

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states, residual)

    # Each test model contains exactly one all-reduce + norm site.
    assert all_reduce_fusion_pass.matched_count == 1
    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
    backend.check_after_ops(model.ops_in_model_after())
    # Drop this frame's reference to the pass. Note that ``backend`` still
    # holds one, so this alone may not finalize it.
    del all_reduce_fusion_pass