# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec

import pytest
import torch

import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed

from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from ..backend import TestBackend


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Unquantized all-reduce + RMSNorm chain for exercising AllReduceFusionPass.

    ``forward`` issues four tensor-parallel all-reduces. The first is followed
    by a plain ``RMSNorm``. The other three are followed by the residual form
    ``norm(x, resid) -> (out, resid)``. Plain ``torch.mm`` matmuls sit between
    consecutive all-reduce/norm pairs.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the four norms and three square matmul weights.

        Args:
            hidden_size: Width of the activations and of each square weight.
            token_num: Not used by this model. The test driver constructs
                every model as ``test_model_cls(hidden_size, token_num)``.
            eps: Epsilon passed to every ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # NOTE: plain Python lists, not torch.nn.ModuleList / Parameters, so
        # none of these are registered on the module.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        """Run the all-reduce / RMSNorm chain and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(x)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = torch.mm(y, self.w[0])
        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = torch.mm(y2, self.w[1])
        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)

        z4 = torch.mm(y3, self.w[2])
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Return the ops handed to ``TestBackend.check_before_ops`` (pre-fusion)."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops handed to ``TestBackend.check_after_ops`` (post-fusion)."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """All-reduce + RMSNorm chain whose matmuls are static FP8 linear layers.

    ``forward`` issues four tensor-parallel all-reduces. The first is followed
    by a plain ``RMSNorm`` and the rest by the residual form
    ``norm(x, resid) -> (out, resid)``. Each hop in between goes through a
    ``TestFP8Layer`` that uses ``quant_key`` for both activations and weights.
    """

    # Quantization scheme used for both activations and weights.
    quant_key = kFp8StaticTensorSym

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the four norms and three FP8 linear layers.

        Args:
            hidden_size: Width of the activations and of each square weight.
            token_num: Not used by this model. The test driver constructs
                every model as ``test_model_cls(hidden_size, token_num)``.
            eps: Epsilon passed to every ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # NOTE: plain Python lists, not torch.nn.ModuleList, so these are not
        # registered as submodules.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.fp8_linear_layers = [
            TestFP8Layer(
                weight_shape=(hidden_size, hidden_size),
                activation_quant_key=self.quant_key,
                weight_quant_key=self.quant_key,
            )
            for _ in range(3)
        ]

    def forward(self, hidden_states):
        """Run the all-reduce / RMSNorm / FP8-linear chain; return the last norm."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = self.fp8_linear_layers[0](y)
        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = self.fp8_linear_layers[1](y2)
        x3 = tensor_model_parallel_all_reduce(z3)
        # residual carried over from the previous fused add + norm
        y3, resid = self.norm[2](x3, resid)

        z4 = self.fp8_linear_layers[2](y3)
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Return the ops handed to ``TestBackend.check_before_ops`` (pre-fusion).

        The quant op expected depends on whether the first FP8 layer's
        ``is_quant_fp8_enabled()`` is true.
        """
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.static_scaled_fp8_quant.default
            if self.fp8_linear_layers[0].is_quant_fp8_enabled()
            # NOTE(review): presumably the non-custom-op quant path shows up
            # as a reciprocal in the traced graph -- confirm.
            else torch.ops.aten.reciprocal.default,
        ]

    def ops_in_model_after(self):
        """Return the ops handed to ``TestBackend.check_after_ops`` (post-fusion)."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
    """All-reduce + RMSNorm chain whose matmuls run through FP4 quantization.

    Each matmul quantizes its activation with ``scaled_fp4_quant`` (using a
    per-layer activation global scale) and multiplies it with
    ``cutlass_scaled_fp4_mm`` against weights quantized once in ``__init__``.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create norms, random weights, global scales and pre-quantized weights.

        Args:
            hidden_size: Width of the activations and of each square weight.
            token_num: Not used by this model. The test driver constructs
                every model as ``test_model_cls(hidden_size, token_num)``.
            eps: Epsilon passed to every ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # NOTE: plain Python lists, not torch.nn.ModuleList, so these are not
        # registered as submodules.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]

        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
        # Per-layer global scales: activation scales are kept for forward();
        # weight scales are only needed to quantize the weights below.
        self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # alpha passed to cutlass_scaled_fp4_mm: 1 / (weight_gscale * act_gscale)
        self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]

        # scaled_fp4_quant returns a (quantized tensor, scale) pair per weight.
        quantized = [scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale)]
        self.wq = [wq for wq, _ in quantized]
        self.wscale = [ws for _, ws in quantized]

    def forward(self, hidden_states):
        """Run the all-reduce / RMSNorm / FP4-matmul chain; return the last norm."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
        z2 = cutlass_scaled_fp4_mm(
            yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
        )
        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
        z3 = cutlass_scaled_fp4_mm(
            yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
        )
        x3 = tensor_model_parallel_all_reduce(z3)
        # residual carried over from the previous fused add + norm
        y3, resid = self.norm[2](x3, resid)

        yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
        z4 = cutlass_scaled_fp4_mm(
            yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
        )
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Return the ops handed to ``TestBackend.check_before_ops`` (pre-fusion)."""
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.scaled_fp4_quant.default,
        ]

    def ops_in_model_after(self):
        """Return the ops handed to ``TestBackend.check_after_ops`` (post-fusion)."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model, enable_quant_fp8_custom_op",
    [
        (TestAllReduceRMSNormModel, False),
        (TestAllReduceRMSNormStaticQuantFP8Model, True),
        (TestAllReduceRMSNormStaticQuantFP8Model, False),
        (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
    not find_spec("flashinfer")
    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not found or flashinfer "
    "is not compiled with trtllm_allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
    test_model: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
):
    """Check that AllReduceFusionPass rewrites every all-reduce + RMSNorm site.

    Spawns one worker process per rank. Each worker runs
    ``all_reduce_fusion_pass_on_test_model`` with the parametrized settings.
    The FP4 model is skipped on devices without compute capability 10.0.
    """
    # Keep in sync with @multi_gpu_test(num_gpus=2) above.
    num_processes = 2

    if (
        test_model is TestAllReduceFusedAddRMSNormStaticQuantFP4Model
        and not current_platform.has_device_capability(100)
    ):
        pytest.skip(
            "Skip as nvfp4 is only supported on "
            "devices with compute capability 10.0 (Blackwell)"
        )

    # spawn() calls fn(process_index, *args); the index becomes local_rank.
    torch.multiprocessing.spawn(
        all_reduce_fusion_pass_on_test_model,
        args=(
            num_processes,
            test_model,
            batch_size,
            seq_len,
            hidden_size,
            dtype,
            enable_rms_norm_custom_op,
            enable_quant_fp8_custom_op,
        ),
        nprocs=num_processes,
    )


def all_reduce_fusion_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
):
    """Per-rank worker: set up TP, compile the test model and check fusion.

    Args:
        local_rank: Rank of this process; also selects ``cuda:<local_rank>``.
        world_size: Number of ranks, used as the tensor-parallel size.
        test_model_cls: Model class, constructed as
            ``test_model_cls(hidden_size, batch_size * seq_len)``.
        batch_size: Batch size; multiplied by ``seq_len`` to get the token count.
        seq_len: Sequence length; multiplied by ``batch_size``.
        hidden_size: Hidden width of the model and of the random input.
        dtype: Default torch dtype set for this process.
        enable_rms_norm_custom_op: Adds ``"+rms_norm"`` to ``custom_ops``.
        enable_quant_fp8_custom_op: Adds ``"+quant_fp8"`` to ``custom_ops``.

    Raises:
        AssertionError: If the fusion pass does not match exactly 4 times
            (one per all-reduce in the test models), or if the backend's
            before/after op checks fail.
    """
    set_random_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # NOTE(review): a fixed MASTER_PORT can collide if several distributed
    # tests run concurrently on the same host.
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
        )
    )
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_allreduce_rms=True, eliminate_noops=True
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
    vllm_config.parallel_config.rank = local_rank  # Setup rank for debug path

    # Placeholder model name, needed only to construct a ModelConfig; this test
    # never runs that model. NOTE(review): ModelConfig may still resolve the
    # name (e.g. fetch its HF config) -- confirm if running offline.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    with set_current_vllm_config(vllm_config):
        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
        noop_pass = NoOpEliminationPass(vllm_config)
        func_pass = FixFunctionalizationPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        backend = TestBackend(
            noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
        )

        token_num = batch_size * seq_len
        model = test_model_cls(hidden_size, token_num)

        hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # Each test model performs four all-reduce + RMSNorm sequences.
        assert all_reduce_fusion_pass.matched_count == 4, (
            f"{all_reduce_fusion_pass.matched_count=}"
        )
        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
        backend.check_after_ops(model.ops_in_model_after())
        del all_reduce_fusion_pass