# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (
    CompilationConfig,
    CUDAGraphMode,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend

# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Sample prompt strings.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Toy model with four all-reduce ops, each followed by an RMSNorm call.

    The first norm is called on the all-reduce output alone; the other three
    are called with the running residual. A matmul against a random square
    weight sits between consecutive stages.
    """

    def __init__(self, hidden_size=16, eps=1e-6):
        """Build four RMSNorm layers and three random weight matrices.

        Args:
            hidden_size: Size passed to each ``RMSNorm`` and used for both
                dimensions of every weight matrix.
            eps: Epsilon passed to each ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Kept as plain Python lists (not ModuleList/ParameterList).
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        """Run the four all-reduce + norm stages and return the last output."""
        z = torch.relu(x)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = torch.mm(y, self.w[0])
        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = torch.mm(y2, self.w[1])
        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)

        z4 = torch.mm(y3, self.w[2])
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Return the ops the test counts in the pre-pass graph."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops the test counts in the post-pass graph."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model(self):
        """Return the norm ops the test looks up in the post-pass graph.

        Empty when the ``RMSNorm`` custom op is not enabled.
        """
        if RMSNorm.enabled():
            return [
                torch.ops._C.rms_norm.default,
                torch.ops._C.fused_add_rms_norm.default,
            ]
        return []


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """Toy model: all-reduce + RMSNorm stages joined by static-FP8 linears.

    Same four-stage structure as the plain RMSNorm test model, but each
    stage's output goes through ``Fp8LinearOp.apply`` with a static
    per-tensor activation scale instead of ``torch.mm``.
    """

    def __init__(self, hidden_size=16, eps=1e-6):
        """Build the norms, FP8 weights, scales, and the FP8 linear op.

        Args:
            hidden_size: Size passed to each ``RMSNorm`` and used for both
                dimensions of every weight matrix.
            eps: Epsilon passed to each ``RMSNorm``.
        """
        super().__init__()
        # Captured at construction time; ops_in_model() reads the pass
        # config from it.
        self.vllm_config = get_current_vllm_config()
        self.hidden_size = hidden_size
        self.eps = eps
        # Kept as plain Python lists (not ModuleList/ParameterList).
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
        # Random weights cast to the platform FP8 dtype, then transposed.
        self.w = [
            torch.rand(hidden_size, hidden_size)
            .to(dtype=current_platform.fp8_dtype())
            .t()
            for _ in range(3)
        ]

        self.fp8_linear = Fp8LinearOp(
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

        # Static input (activation) scales, one per linear.
        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

    def forward(self, hidden_states):
        """Run the four all-reduce + norm stages and return the last output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        z4 = self.fp8_linear.apply(
            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
        )
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the ops the test counts in the pre-pass graph."""
        return [
            torch.ops.vllm.all_reduce.default,
        ]

    def ops_in_model_after(self):
        """Return the ops the test counts in the post-pass graph."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model(self):
        """Return the norm/quant ops the test looks up in the post-pass graph.

        The branches are checked in priority order: fusion enabled, then the
        ``RMSNorm`` custom op enabled, then the FP8 quant custom op enabled.
        Returns an empty list when none applies.
        """
        if self.vllm_config.compilation_config.pass_config.enable_fusion:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        if RMSNorm.enabled():
            return [
                torch.ops._C.fused_add_rms_norm.default,
            ]
        if self.fp8_linear.quant_fp8.enabled():
            return [
                torch.ops._C.static_scaled_fp8_quant.default,
            ]
        return []


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model_cls, custom_ops",
    [
        (TestAllReduceRMSNormModel, "+rms_norm"),
        (TestAllReduceRMSNormModel, "-rms_norm"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_sequence_parallelism_pass(
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_fusion: bool,
    dynamic: bool,
):
    """Spawn two worker processes that each run the sequence-parallelism check.

    Every parametrized argument is forwarded unchanged to
    ``sequence_parallelism_pass_on_test_model``; ``spawn`` supplies the
    process index as that function's first argument.
    """
    num_processes = 2

    # need to use torch.mp.spawn otherwise will have problems with
    # torch.distributed and cuda
    torch.multiprocessing.spawn(
        sequence_parallelism_pass_on_test_model,
        args=(
            num_processes,
            test_model_cls,
            custom_ops,
            batch_size,
            seq_len,
            hidden_size,
            dtype,
            enable_fusion,
            dynamic,
        ),
        nprocs=num_processes,
    )


def sequence_parallelism_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    enable_fusion: bool,
    dynamic: bool,
):
    """Per-process body: compile a test model and check the pass's rewrites.

    Sets up the device and distributed state for this rank, builds a
    ``VllmConfig`` with sequence parallelism enabled, compiles
    ``test_model_cls`` through a ``TestBackend`` running the configured
    passes, and asserts on the ops found before and after the passes.

    Args:
        local_rank: Index of this process; selects ``cuda:{local_rank}``.
        world_size: Total number of processes; used as the tensor-parallel
            size.
        test_model_cls: Model class to instantiate with ``hidden_size``.
        custom_ops: Comma-separated custom-op toggles (e.g. ``"+rms_norm"``);
            an empty string yields an empty list.
        batch_size: Multiplied by ``seq_len`` to get the input's first
            dimension.
        seq_len: See ``batch_size``.
        hidden_size: Input's second dimension and the model's hidden size.
        dtype: Default dtype and the dtype of the input tensor.
        enable_fusion: Whether to enable fusion in the pass config and append
            ``RMSNormQuantFusionPass`` to the pass list.
        dynamic: Whether to mark dimension 0 of the input as dynamic.
    """
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for SequenceParallelismPass
    custom_ops_list = custom_ops.split(",") if custom_ops else []
    compilation_config = CompilationConfig(
        splitting_ops=[],  # avoid automatic rms_norm enablement
        cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
        custom_ops=custom_ops_list,
        pass_config=PassConfig(
            enable_sequence_parallelism=True,
            enable_fusion=enable_fusion,
            enable_noop=True,  # NoOp needed for fusion
        ),
    )
    device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        device_config=device_config,
        compilation_config=compilation_config,
    )

    with set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
        sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        # The pass must see the same compilation settings as the config it
        # was constructed from.
        assert (
            sequence_parallelism_pass.compilation_config.splitting_ops
            == vllm_config.compilation_config.splitting_ops
        )
        assert (
            sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
            == vllm_config.compilation_config.use_inductor_graph_partition
        )

        # Pass order: noop -> sequence parallelism -> (fusion) -> cleanup.
        passes_for_backend: list[VllmInductorPass] = [
            noop_pass,
            sequence_parallelism_pass,
        ]
        if enable_fusion:
            fusion_pass = RMSNormQuantFusionPass(vllm_config)
            passes_for_backend.append(fusion_pass)
        passes_for_backend.append(cleanup_pass)

        backend = TestBackend(*passes_for_backend)

        model = test_model_cls(hidden_size)

        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

        if dynamic:
            torch._dynamo.mark_dynamic(hidden_states, 0)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # One match per all-reduce in the test models.
        assert sequence_parallelism_pass.matched_count == 4

        # In the pre-pass graph, each "before" op (all-reduce) must
        # appear four times.
        for op in model.ops_in_model_before():
            assert backend.op_count(op, before=True) == 4

        # In the post-pass graph, each "after" op (all-gather,
        # reduce-scatter) must appear four times.
        for op in model.ops_in_model_after():
            assert backend.op_count(op, before=False) == 4

        # Look up each expected norm/quant op in the post-pass graph.
        for op in model.ops_in_model():
            find_auto_fn(backend.graph_post_pass.nodes, op)