test_fusion_all_reduce.py 11.1 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec

import pytest
import torch

import vllm.envs as envs
9
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
10
from vllm.compilation.collective_fusion import AllReduceFusionPass
11
12
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
13
from vllm.compilation.post_cleanup import PostCleanupPass
14
15
from vllm.config import (
    CompilationConfig,
16
    CompilationMode,
17
18
19
20
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
21
    set_current_vllm_config,
22
)
23
from vllm.distributed import tensor_model_parallel_all_reduce
24
25
26
27
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
28
from vllm.model_executor.layers.layernorm import RMSNorm
29
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
30
    Fp8LinearOp,
31
32
    GroupShape,
)
33
from vllm.platforms import current_platform
34
from vllm.utils.system_utils import update_environment_variables
35
from vllm.utils.torch_utils import set_random_seed
36

37
38
from ...utils import has_module_attribute, multi_gpu_test
from ..backend import TestBackend
39
40
41


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Unquantized test model: four all-reduce -> RMSNorm stages joined by matmuls.

    Every stage calls ``tensor_model_parallel_all_reduce`` and feeds the
    result directly into an ``RMSNorm`` layer. The test in this file asserts
    that the fusion pass matches four such sites on this model.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the norm layers and the random projection weights.

        Args:
            hidden_size: Width of the activations and of each square weight.
            token_num: Not read by this class; the caller in this file
                constructs every test model as ``cls(hidden_size, token_num)``.
            eps: Epsilon forwarded to each ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # One norm per all-reduce site (4) and one weight per matmul between
        # consecutive sites (3).
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        """Run the four all-reduce/RMSNorm stages and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(x)
        # The first all-reduce output doubles as the residual that is threaded
        # through the remaining norm calls.
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = torch.mm(y, self.w[0])
        x2 = tensor_model_parallel_all_reduce(z2)
        # From here on the norm is called with a residual and returns an
        # (output, residual) pair.
        y2, resid = self.norm[1](x2, resid)

        z3 = torch.mm(y2, self.w[1])
        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)

        z4 = torch.mm(y3, self.w[2])
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Return the ops the caller checks for in the graph before the passes."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops the caller checks for in the graph after the passes."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


78
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """FP8 test model: all-reduce -> RMSNorm -> statically quantized FP8 linear.

    Same four-stage layout as the unquantized test model, but the projection
    between stages goes through ``Fp8LinearOp.apply`` with a per-tensor static
    activation scale.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the norm layers, FP8 weights, scales and the FP8 linear op.

        Args:
            hidden_size: Width of the activations and of each square weight.
            token_num: Not read by this class; the caller in this file
                constructs every test model as ``cls(hidden_size, token_num)``.
            eps: Epsilon forwarded to each ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        # Per-weight scales handed to Fp8LinearOp.apply alongside each weight.
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # Random weights cast to the platform's FP8 dtype, then transposed.
        self.w = [
            torch.rand(hidden_size, hidden_size)
            .to(dtype=current_platform.fp8_dtype())
            .t()
            for _ in range(3)
        ]

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

        # Static activation (input) scales, one per FP8 linear call.
        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

    def forward(self, hidden_states):
        """Run the four all-reduce/RMSNorm stages and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        # The first all-reduce output doubles as the residual that is threaded
        # through the remaining norm calls.
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        z4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Return the ops the caller checks for in the graph after the passes."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

    def ops_in_model_before(self):
        """Return the ops the caller checks for in the graph before the passes.

        The quantization op depends on whether the ``quant_fp8`` custom op is
        enabled on the linear op: the custom static-quant kernel when it is,
        ``aten.reciprocal`` otherwise.
        """
        # NOTE(review): aten.reciprocal is presumably what the non-custom
        # quant path traces to -- confirm against the QuantFP8 implementation.
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.static_scaled_fp8_quant.default
            if self.fp8_linear.quant_fp8.enabled()
            else torch.ops.aten.reciprocal.default,
        ]


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
    """FP4 test model: all-reduce -> RMSNorm -> FP4 quant -> FP4 matmul.

    Same four-stage layout as the other test models in this file, but each
    norm output is quantized with ``scaled_fp4_quant`` and multiplied by a
    pre-quantized weight through ``cutlass_scaled_fp4_mm``.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the norm layers and pre-quantize the random weights to FP4.

        Args:
            hidden_size: Width of the activations and of each square weight.
            token_num: Not read by this class; the caller in this file
                constructs every test model as ``cls(hidden_size, token_num)``.
            eps: Epsilon forwarded to each ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]

        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
        # Scales passed to scaled_fp4_quant for activations (kept on self for
        # forward) and for weights (only needed here).
        self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # alpha is handed to cutlass_scaled_fp4_mm; it is the reciprocal of
        # the product of the weight and activation scales.
        self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]

        # scaled_fp4_quant returns a (quantized tensor, scale) pair per weight;
        # transpose the list of pairs into two parallel lists.
        wq_gen, wscale_gen = zip(
            *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
        )
        self.wq, self.wscale = list(wq_gen), list(wscale_gen)

    def forward(self, hidden_states):
        """Run the four all-reduce/RMSNorm stages and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        # The first all-reduce output doubles as the residual that is threaded
        # through the remaining norm calls.
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
        z2 = cutlass_scaled_fp4_mm(
            yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
        )

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
        z3 = cutlass_scaled_fp4_mm(
            yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
        )

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
        z4 = cutlass_scaled_fp4_mm(
            yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
        )
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Return the ops the caller checks for in the graph after the passes."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

    def ops_in_model_before(self):
        """Return the ops the caller checks for in the graph before the passes."""
        return [
            torch.ops.vllm.all_reduce.default,
            torch.ops._C.scaled_fp4_quant.default,
        ]


196
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model, enable_quant_fp8_custom_op",
    [
        (TestAllReduceRMSNormModel, False),
        (TestAllReduceRMSNormStaticQuantFP8Model, True),
        (TestAllReduceRMSNormStaticQuantFP8Model, False),
        (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
    not find_spec("flashinfer")
    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not found or flashinfer "
    "is not compiled with trtllm_allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
    test_model: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
):
    """Spawn two workers that run the all-reduce fusion check on ``test_model``.

    Args:
        test_model: Model *class* (not an instance); each worker instantiates
            it on its own device.
        batch_size: Multiplied by ``seq_len`` in the worker to get the token
            count.
        seq_len: See ``batch_size``.
        hidden_size: Hidden dimension passed to the model constructor.
        dtype: Default torch dtype set inside each worker.
        enable_rms_norm_custom_op: Whether the worker adds ``+rms_norm`` to
            the compilation config's custom ops.
        enable_quant_fp8_custom_op: Whether the worker adds ``+quant_fp8`` to
            the compilation config's custom ops.
    """
    # Must agree with the num_gpus requested by the multi_gpu_test decorator.
    num_processes = 2

    is_fp4_model = test_model is TestAllReduceFusedAddRMSNormStaticQuantFP4Model
    if is_fp4_model and not current_platform.has_device_capability(100):
        pytest.skip(
            "Skip as nvfp4 is only supported on "
            "devices with compute capability 10.0 (Blackwell)"
        )

    # spawn() prepends the process index, which the worker receives as its
    # first positional argument (local_rank).
    torch.multiprocessing.spawn(
        all_reduce_fusion_pass_on_test_model,
        args=(
            num_processes,
            test_model,
            batch_size,
            seq_len,
            hidden_size,
            dtype,
            enable_rms_norm_custom_op,
            enable_quant_fp8_custom_op,
        ),
        nprocs=num_processes,
    )


256
257
258
259
260
261
262
263
def all_reduce_fusion_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_rms_norm_custom_op: bool,
    enable_quant_fp8_custom_op: bool,
):
    """Per-process worker: compile the model and check all-reduce fusion.

    Sets up the CUDA device and distributed state for this rank, builds a
    ``VllmConfig`` with all-reduce/RMSNorm fusion enabled, compiles an
    instance of ``test_model_cls`` through ``TestBackend`` and asserts that
    the fusion pass matched four sites and that the expected ops are present
    before and after the passes.

    Args:
        local_rank: Index of this process; also used as the CUDA device index
            and as the distributed rank.
        world_size: Number of processes; used as the tensor-parallel size.
        test_model_cls: Model class, instantiated as
            ``test_model_cls(hidden_size, token_num)``.
        batch_size: Multiplied by ``seq_len`` to get the token count.
        seq_len: See ``batch_size``.
        hidden_size: Hidden dimension of the model and of the random input.
        dtype: Set as the default torch dtype and passed to ``ModelConfig``.
        enable_rms_norm_custom_op: Adds ``+rms_norm`` to the custom ops.
        enable_quant_fp8_custom_op: Adds ``+quant_fp8`` to the custom ops.
    """
    set_random_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # NOTE(review): the rendezvous port is fixed, so two runs of this test on
    # the same host at the same time would presumably collide -- confirm
    # whether CI ever runs it concurrently.
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    custom_ops = []
    if enable_rms_norm_custom_op:
        custom_ops.append("+rms_norm")
    if enable_quant_fp8_custom_op:
        custom_ops.append("+quant_fp8")

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
        )
    )
    vllm_config.compilation_config.pass_config = PassConfig(
        fuse_allreduce_rms=True, eliminate_noops=True
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
    vllm_config.parallel_config.rank = local_rank  # Setup rank for debug path

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    with set_current_vllm_config(vllm_config):
        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
        noop_pass = NoOpEliminationPass(vllm_config)
        func_pass = FixFunctionalizationPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        # Passes are handed to the backend in this order: no-op elimination,
        # all-reduce fusion, functionalization fix-up, cleanup.
        backend = TestBackend(
            noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
        )

        token_num = batch_size * seq_len
        model = test_model_cls(hidden_size, token_num)

        hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # Every test model in this file has four all-reduce -> norm sites.
        assert all_reduce_fusion_pass.matched_count == 4, (
            f"{all_reduce_fusion_pass.matched_count=}"
        )
        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
        backend.check_after_ops(model.ops_in_model_after())
        # NOTE(review): explicit del presumably lets the pass release its
        # resources before the worker exits -- confirm against
        # AllReduceFusionPass.
        del all_reduce_fusion_pass