# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.distributed import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed

from ...models.registry import HF_EXAMPLE_MODELS
from ...utils import (
    compare_two_settings,
    create_new_process_for_each_test,
    multi_gpu_test,
)
from ..backend import TestBackend

# FP8 dtype reported by the current platform; used by the scaled-mm test
# models to cast activations and allocate weights.
FP8_DTYPE = current_platform.fp8_dtype()

# Sample prompts (module-level; not used by the tests in this module).
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestMMRSModel(torch.nn.Module):
    """Toy model whose FX graph contains ``mm -> reduce_scatter``.

    The async-TP test expects AsyncTPPass to replace this pair with
    ``symm_mem.fused_matmul_reduce_scatter``.
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        super().__init__()
        self.hidden_size = hidden_size
        # Kept on the instance; the weight below is allocated with the
        # default torch dtype, not with this value.
        self.dtype = dtype
        # Frozen (2 * hidden_size, hidden_size) projection weight.
        self.gate_proj = torch.nn.Parameter(
            torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
        )
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states):
        """Flatten to (-1, hidden_size), multiply by ``gate_proj.T``, then
        reduce-scatter the product along dim 0 across the TP group.
        """
        view = hidden_states.reshape(-1, self.hidden_size)
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)
        reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
        return reduce_scatter

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]


class TestAGMMModel(torch.nn.Module):
    """Toy model whose FX graph contains ``all_gather -> mm``.

    The async-TP test expects AsyncTPPass to replace this pair with
    ``symm_mem.fused_all_gather_matmul``.
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        super().__init__()
        self.hidden_size = hidden_size
        # Kept on the instance; the weight below is allocated with the
        # default torch dtype, not with this value.
        self.dtype = dtype
        # Frozen square (hidden_size, hidden_size) weight.
        self.weight = torch.nn.Parameter(
            torch.empty((hidden_size, hidden_size)), requires_grad=False
        )
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states):
        """Flatten to (-1, hidden_size), all-gather along dim 0 across the TP
        group, then multiply by ``weight.T``.
        """
        view = hidden_states.reshape(-1, self.hidden_size)
        all_gather = tensor_model_parallel_all_gather(view, dim=0)
        permute = self.weight.permute(1, 0)
        mm = torch.mm(all_gather, permute)
        return mm

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_matmul.default]


class _BaseScaledMMModel(torch.nn.Module):
    """Shared setup for the FP8 scaled-mm toy models.

    Provides an FP8 ``weight`` and a float32 ``scale_b``; subclasses supply
    ``forward`` and the expected before/after op lists.
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        super().__init__()
        self.hidden_size = hidden_size
        # Output dtype requested from the scaled matmul in subclasses.
        self.dtype = dtype
        # Allocated contiguous then transposed, so the stored weight is a
        # column-major view; presumably the layout the FP8 GEMM kernels
        # expect for the second operand — TODO confirm.
        # NOTE(review): torch.empty leaves the values uninitialized; the
        # pattern-replacement test only checks graph ops, not numerics.
        self.weight = (
            torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
            .contiguous()
            .transpose(0, 1)
        )
        # Weight scale of shape (1, hidden_size): one value per output column.
        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)


class TestScaledMMRSModel(_BaseScaledMMModel):
    """Toy model whose FX graph contains ``_scaled_mm -> reduce_scatter``."""

    def forward(self, input: torch.Tensor):
        """Cast ``input`` to FP8, run ``torch._scaled_mm`` with all-ones
        row-wise activation scales, then reduce-scatter along dim 0.
        """
        fp8_input = input.to(FP8_DTYPE)
        # One scale per input row.
        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
        scaled_mm = torch._scaled_mm(
            fp8_input,
            self.weight,
            scale_a=scale_a,
            scale_b=self.scale_b,
            out_dtype=self.dtype,
        )
        reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
        return reduce_scatter

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]


class TestAGScaledMMModel(_BaseScaledMMModel):
    """Toy model whose FX graph contains ``all_gather -> _scaled_mm``."""

    def forward(self, input: torch.Tensor):
        """Cast ``input`` to FP8, all-gather it along dim 0, then run
        ``torch._scaled_mm`` with all-ones row-wise activation scales.
        """
        # Gather the already-FP8 activations so the collective moves FP8 data.
        fp8_input = input.to(FP8_DTYPE)
        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)

        # One scale per gathered row.
        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
        scaled_mm = torch._scaled_mm(
            all_gather,
            self.weight,
            scale_a=scale_a,
            scale_b=self.scale_b,
            out_dtype=self.dtype,
        )
        return scaled_mm

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]


class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
    """Toy model whose FX graph contains
    ``cutlass_scaled_mm -> reduce_scatter``.
    """

    def forward(self, input: torch.Tensor):
        """Cast ``input`` to FP8, run ``_C.cutlass_scaled_mm`` into a
        preallocated output, then reduce-scatter along dim 0.
        """
        fp8_input = input.to(FP8_DTYPE)
        # One scale per input row.
        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
        # cutlass_scaled_mm writes into this buffer in place.
        mm_out = torch.empty(
            (fp8_input.shape[0], self.weight.shape[1]),
            dtype=self.dtype,
            device=input.device,
        )
        # The trailing None is presumably the optional bias — TODO confirm.
        torch.ops._C.cutlass_scaled_mm(
            mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
        )
        reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
        return reduce_scatter

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]


class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
    """Toy model whose FX graph contains
    ``all_gather -> cutlass_scaled_mm``.
    """

    def forward(self, input: torch.Tensor):
        """Cast ``input`` to FP8, all-gather it along dim 0, then run
        ``_C.cutlass_scaled_mm`` into a preallocated output.
        """
        # Gather the already-FP8 activations so the collective moves FP8 data.
        fp8_input = input.to(FP8_DTYPE)
        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)

        # One scale per gathered row.
        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)

        # cutlass_scaled_mm writes into this buffer in place.
        mm_out = torch.empty(
            (all_gather.shape[0], self.weight.shape[1]),
            dtype=self.dtype,
            device=all_gather.device,
        )
        # The trailing None is presumably the optional bias — TODO confirm.
        torch.ops._C.cutlass_scaled_mm(
            mm_out, all_gather, self.weight, scale_a, self.scale_b, None
        )
        return mm_out

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model",
    [
        TestMMRSModel,
        TestAGMMModel,
        TestScaledMMRSModel,
        TestAGScaledMMModel,
        TestCutlassScaledMMRSModel,
        TestAGCutlassScaledMMModel,
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dynamic", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(
    test_model: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    dynamic: bool,
):
    """Check that AsyncTPPass fuses each toy model's collective + GEMM pair.

    Spawns two processes that each run ``async_tp_pass_on_test_model`` with
    the parametrized model class, shapes, dtype, and dynamic-shape flag.
    """
    # The FP8 models are skipped for fp16 output (see skip reason below).
    fp8_models = (
        TestScaledMMRSModel,
        TestAGScaledMMModel,
        TestCutlassScaledMMRSModel,
        TestAGCutlassScaledMMModel,
    )
    if test_model in fp8_models and dtype == torch.float16:
        pytest.skip(
            "Only bf16 high precision output types are supported for "
            "per-token (row-wise) scaling"
        )

    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(
            fn,
            args=(
                num_processes,
                test_model,
                batch_size,
                seq_len,
                hidden_size,
                dtype,
                dynamic,
            ),
            nprocs=nprocs,
        )

    run_torch_spawn(async_tp_pass_on_test_model, num_processes)


def async_tp_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    dynamic: bool,
):
    """Per-process worker for ``test_async_tp_pass_replace``.

    Initializes a ``world_size``-way tensor-parallel group, compiles an
    instance of ``test_model_cls`` with a TestBackend wrapping AsyncTPPass,
    runs it once, and asserts that exactly one pattern was matched and the
    expected ops were replaced.

    Args:
        local_rank: Process index passed by ``torch.multiprocessing.spawn``;
            used as the CUDA device index and the distributed rank.
        world_size: Number of spawned processes / TP size.
        test_model_cls: Toy model class exposing ``ops_in_model_before`` and
            ``ops_in_model_after``.
        batch_size, seq_len, hidden_size: Input is
            ``(batch_size * seq_len, hidden_size)``.
        dtype: Default torch dtype and model/output dtype.
        dynamic: If True, mark input dim 0 as dynamic for torch.compile.
    """
    set_random_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # NOTE(review): fixed MASTER_PORT may collide with concurrent test runs.
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for AsyncTPPass (GEMM + collective fusion)
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(
            fuse_gemm_comms=True,
        ),
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    async_tp_pass = AsyncTPPass(vllm_config)

    # Set the global vllm_config for TestBackend which calls
    # get_current_vllm_config()
    with set_current_vllm_config(vllm_config):
        backend = TestBackend(async_tp_pass)

        # The pass must see the same compilation settings as the global config.
        assert (
            async_tp_pass.compilation_config.splitting_ops
            == vllm_config.compilation_config.splitting_ops
        )
        assert (
            async_tp_pass.compilation_config.use_inductor_graph_partition
            == vllm_config.compilation_config.use_inductor_graph_partition
        )

        model = test_model_cls(hidden_size, dtype)  # Pass dtype to model constructor

        hidden_states = torch.randn(
            (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
        )

        if dynamic:
            torch._dynamo.mark_dynamic(hidden_states, 0)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        assert async_tp_pass.matched_count == 1

        # In pre-nodes, all gather or reduce scatter should exist,
        # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

        # In post-nodes, fused_matmul_reduce_scatter or
        # fused_all_gather_matmul should exist
        backend.check_after_ops(model.ops_in_model_after())


@create_new_process_for_each_test()
@pytest.mark.parametrize(
    "model_id",
    ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False, True])
def test_async_tp_pass_correctness(
    model_id: str,
    tp_size: int,
    async_tp_enabled: bool,
    distributed_backend: str,
    eager_mode: bool,
    num_gpus_available: int,
):
    """Compare generation with the async-TP compilation config against a
    plain tensor-parallel run of the same model via ``compare_two_settings``.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")
    model_info.check_available_online(on_fail="skip")

    pp_size = 1
    # Check against the full TP x PP requirement so it matches the message.
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

    common_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if eager_mode:
        common_args.append("--enforce-eager")

    compilation_config = {
        "mode": CompilationMode.VLLM_COMPILE,
        "compile_sizes": [2, 4, 8],
        "splitting_ops": [],
        "pass_config": {"fuse_gemm_comms": async_tp_enabled},
    }

    async_tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(compilation_config),
    ]

    # Baseline: same TP size, default compilation config.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    compare_two_settings(model_id, async_tp_args, tp_args, method="generate")