# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import torch

import vllm.envs as envs
9
10
11
12
13
from tests.compile.backend import TestBackend
from tests.utils import (
    multi_gpu_test,
)
from vllm.compilation.passes.fusion.collective_fusion import AsyncTPPass
14
15
16
17
18
19
from vllm.config import (
    CompilationConfig,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
20
    set_current_vllm_config,
21
22
23
24
25
26
27
28
29
)
from vllm.distributed import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
30
from vllm.platforms import current_platform
31
from vllm.utils.system_utils import update_environment_variables
32
from vllm.utils.torch_utils import set_random_seed
33

34
35
# FP8 dtype reported by the current platform; the scaled-matmul test models
# below use it for their weights and for casting their inputs.
FP8_DTYPE = current_platform.fp8_dtype()

36
37
38
39
40
41
42
43
44
# Sample text prompts.
# NOTE(review): nothing in the visible part of this file references
# `prompts` — confirm whether it is still needed before removing it.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestMMRSModel(torch.nn.Module):
    """Toy module whose forward pass is ``mm`` followed by a reduce-scatter.

    The module exists so that the compiled FX graph contains a matmul feeding
    a tensor-parallel reduce-scatter, and it reports which ops are expected
    in the graph before and after the fusion pass under test.
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        """Create the projection weight.

        Args:
            hidden_size: Size of the input feature dimension; the weight has
                shape ``(2 * hidden_size, hidden_size)``.
            dtype: Dtype recorded on the module and used for the weight.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.dtype = dtype
        # Allocate the weight in the requested dtype rather than silently
        # falling back to the process-wide default dtype, so the module does
        # not depend on the caller having called torch.set_default_dtype.
        self.gate_proj = torch.nn.Parameter(
            torch.empty((self.hidden_size * 2, hidden_size), dtype=dtype),
            requires_grad=False,
        )
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states):
        """
        Forward pass implementing the mm + reduce scatter in the FX graph
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)
        reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
        return reduce_scatter

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Return the ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]


class TestAGMMModel(torch.nn.Module):
    """Toy module whose forward pass is an all-gather followed by ``mm``.

    The module exists so that the compiled FX graph contains a
    tensor-parallel all-gather feeding a matmul, and it reports which ops are
    expected in the graph before and after the fusion pass under test.
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        """Create the square weight.

        Args:
            hidden_size: Size of the feature dimension; the weight has shape
                ``(hidden_size, hidden_size)``.
            dtype: Dtype recorded on the module and used for the weight.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.dtype = dtype
        # Allocate the weight in the requested dtype rather than silently
        # falling back to the process-wide default dtype, so the module does
        # not depend on the caller having called torch.set_default_dtype.
        self.weight = torch.nn.Parameter(
            torch.empty((hidden_size, hidden_size), dtype=dtype),
            requires_grad=False,
        )
        # Initialize weights
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states):
        """
        Forward pass implementing the mm + all gather in the FX graph
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)
        all_gather = tensor_model_parallel_all_gather(view, dim=0)
        permute = self.weight.permute(1, 0)
        mm = torch.mm(all_gather, permute)
        return mm

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Return the ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_matmul.default]


105
106
107
108
109
class _BaseScaledMMModel(torch.nn.Module):
    """Common state for the FP8 scaled-matmul test models.

    Holds a square ``FP8_DTYPE`` weight (stored as the transposed view of a
    contiguous buffer) and a float32 ``scale_b`` of ones with shape
    ``(1, hidden_size)``. Subclasses supply ``forward``.
    """

    def __init__(self, hidden_size=16, dtype=torch.float16):
        """Record the sizes/dtype and allocate the weight and its scale.

        Args:
            hidden_size: Side length of the square weight.
            dtype: Stored on ``self.dtype``; subclasses use it as the matmul
                output dtype.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.dtype = dtype

        # The weight values are left uninitialized; only its layout matters
        # here: a contiguous buffer viewed through a (0, 1) transpose.
        fp8_buffer = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
        self.weight = fp8_buffer.contiguous().transpose(0, 1)

        # Initialize scale_b for _scaled_mm.
        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)


class TestScaledMMRSModel(_BaseScaledMMModel):
    """FP8 ``torch._scaled_mm`` followed by a tensor-parallel reduce-scatter."""

    def forward(self, input: torch.Tensor):
        """
        Forward pass implementing the scaled_mm + reduce scatter in the FX graph
        """
        # Cast the activations to FP8 and pair them with all-ones
        # per-row scales of shape (rows, 1).
        quantized = input.to(FP8_DTYPE)
        row_scales = torch.ones(input.shape[0], 1, dtype=torch.float32)

        product = torch._scaled_mm(
            quantized,
            self.weight,
            scale_a=row_scales,
            scale_b=self.scale_b,
            out_dtype=self.dtype,
        )
        # Scatter the matmul result along dim 0.
        return tensor_model_parallel_reduce_scatter(product, dim=0)

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
143
144
145
146
147
148
149
150
151
152
153
154


class TestAGScaledMMModel(_BaseScaledMMModel):
    """Tensor-parallel all-gather followed by FP8 ``torch._scaled_mm``."""

    def forward(self, input: torch.Tensor):
        """
        Forward pass implementing the all gather + scaled_mm in the FX graph
        """
        # Cast to FP8 first, then gather the FP8 activations along dim 0.
        quantized = input.to(FP8_DTYPE)
        gathered = tensor_model_parallel_all_gather(quantized, dim=0)

        # All-ones per-row scales sized to the gathered row count.
        row_scales = torch.ones(gathered.shape[0], 1, dtype=torch.float32)
        return torch._scaled_mm(
            gathered,
            self.weight,
            scale_a=row_scales,
            scale_b=self.scale_b,
            out_dtype=self.dtype,
        )

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]


class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
    """``cutlass_scaled_mm`` followed by a tensor-parallel reduce-scatter."""

    def forward(self, input: torch.Tensor):
        """
        Forward pass implementing the cutlass_scaled_mm + reduce scatter
        in the FX graph
        """
        quantized = input.to(FP8_DTYPE)
        row_scales = torch.ones(input.shape[0], 1, dtype=torch.float32)

        # The cutlass op writes into a caller-provided output buffer, so
        # allocate it up front on the input's device.
        out_shape = (quantized.shape[0], self.weight.shape[1])
        product = torch.empty(out_shape, dtype=self.dtype, device=input.device)
        torch.ops._C.cutlass_scaled_mm(
            product, quantized, self.weight, row_scales, self.scale_b, None
        )

        # Scatter the matmul result along dim 0.
        return tensor_model_parallel_reduce_scatter(product, dim=0)

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
196
197
198
199
200


class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
    """Tensor-parallel all-gather followed by ``cutlass_scaled_mm``."""

    def forward(self, input: torch.Tensor):
        """
        Forward pass implementing the all gather + cutlass_scaled_mm
        in the FX graph
        """
        # Cast to FP8 first, then gather the FP8 activations along dim 0.
        quantized = input.to(FP8_DTYPE)
        gathered = tensor_model_parallel_all_gather(quantized, dim=0)

        # All-ones per-row scales sized to the gathered row count.
        row_scales = torch.ones(gathered.shape[0], 1, dtype=torch.float32)

        # The cutlass op writes into a caller-provided output buffer, so
        # allocate it up front on the gathered tensor's device.
        out_shape = (gathered.shape[0], self.weight.shape[1])
        product = torch.empty(out_shape, dtype=self.dtype, device=gathered.device)
        torch.ops._C.cutlass_scaled_mm(
            product, gathered, self.weight, row_scales, self.scale_b, None
        )
        return product

    def ops_in_model_before(self):
        """Ops expected in the graph before the pass runs."""
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        """Ops expected in the graph after the pass runs."""
        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]


227
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model",
    [
        TestMMRSModel,
        TestAGMMModel,
        TestScaledMMRSModel,
        TestAGScaledMMModel,
        TestCutlassScaledMMRSModel,
        TestAGCutlassScaledMMModel,
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dynamic", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(
    test_model: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    dynamic: bool,
):
    """Run the AsyncTP pass check for one model/dtype/shape combination.

    Spawns two worker processes that each execute
    ``async_tp_pass_on_test_model`` with the parametrized arguments.

    Args:
        test_model: The model *class* (not an instance) to compile.
        batch_size: Multiplied by ``seq_len`` to size the input's first dim.
        seq_len: See ``batch_size``.
        hidden_size: Feature dimension of the input and the model weights.
        dtype: Dtype of the input and the default dtype in the workers.
        dynamic: Whether the workers mark the input's first dim as dynamic.
    """
    fp8_models = (
        TestScaledMMRSModel,
        TestAGScaledMMModel,
        TestCutlassScaledMMRSModel,
        TestAGCutlassScaledMMModel,
    )
    if test_model in fp8_models and dtype == torch.float16:
        pytest.skip(
            "Only bf16 high precision output types are supported for "
            "per-token (row-wise) scaling"
        )

    num_processes = 2

    # need to use torch.mp.spawn otherwise will have problems with
    # torch.distributed and cuda
    torch.multiprocessing.spawn(
        async_tp_pass_on_test_model,
        args=(
            num_processes,
            test_model,
            batch_size,
            seq_len,
            hidden_size,
            dtype,
            dynamic,
        ),
        nprocs=num_processes,
    )


290
291
292
293
294
295
296
297
def async_tp_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    dynamic: bool,
):
    """Per-process body: compile one test model with ``AsyncTPPass`` and check it.

    Sets up the CUDA device and distributed environment for this rank, builds
    a ``VllmConfig`` with ``fuse_gemm_comms`` enabled, compiles the model
    through a ``TestBackend`` wrapping the pass, and asserts that the pass
    matched exactly once and that the expected ops appear before/after it.

    Args:
        local_rank: Rank of this process; also selects ``cuda:{local_rank}``.
        world_size: Total number of processes / tensor-parallel size.
        test_model_cls: Model class; instantiated as
            ``test_model_cls(hidden_size, dtype)``.
        batch_size: Multiplied by ``seq_len`` to size the input's first dim.
        seq_len: See ``batch_size``.
        hidden_size: Feature dimension of the input.
        dtype: Input dtype; also installed as the torch default dtype.
        dynamic: If true, mark dim 0 of the input as dynamic before compiling.
    """
    set_random_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment()

    # configure vllm config for AsyncTPPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(
            fuse_gemm_comms=True,
        ),
    )
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    vllm_config.model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    with set_current_vllm_config(vllm_config):
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        async_tp_pass = AsyncTPPass(vllm_config)
        backend = TestBackend(async_tp_pass)

        # The pass must see the same compilation settings as the config it
        # was constructed from.
        assert (
            async_tp_pass.compilation_config.splitting_ops
            == vllm_config.compilation_config.splitting_ops
        )
        assert (
            async_tp_pass.compilation_config.use_inductor_graph_partition
            == vllm_config.compilation_config.use_inductor_graph_partition
        )

        model = test_model_cls(hidden_size, dtype)  # Pass dtype to model constructor

        hidden_states = torch.randn(
            (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
        )

        if dynamic:
            torch._dynamo.mark_dynamic(hidden_states, 0)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        assert async_tp_pass.matched_count == 1

        # In pre-nodes, all gather or reduce scatter should exist,
        # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

        # In post-nodes, fused_matmul_reduce_scatter or \
        # fused_all_gather_matmul should exist
        backend.check_after_ops(model.ops_in_model_after())