test_sequence_parallelism.py 11.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7

import pytest
import torch

import vllm.envs as envs
8
9
10
11
12
13
14
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
15
16
from vllm.config import (
    CompilationConfig,
17
    CUDAGraphMode,
18
19
20
21
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
22
23
    get_current_vllm_config,
    set_current_vllm_config,
24
)
25
from vllm.config.utils import Range
26
from vllm.distributed import tensor_model_parallel_all_reduce
27
28
29
30
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
31
from vllm.model_executor.layers.layernorm import RMSNorm
32
33
34
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
35
from vllm.platforms import current_platform
36
from vllm.utils.system_utils import update_environment_variables
37
from vllm.utils.torch_utils import set_random_seed
38

39
40
# Device type string reported by the current platform; used to build the
# per-rank torch.device in the spawned worker.
DEVICE_TYPE = current_platform.device_type

# Module-wide marker: skip every test here unless the platform is CUDA.
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")

# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Sample prompts. NOTE(review): not referenced by the code visible in this
# file -- confirm whether they are still needed.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


52
53
class TestAllReduceRMSNormModel(torch.nn.Module):
    """Small model with four all-reduce + RMSNorm stages for pass testing.

    The forward pass performs four ``tensor_model_parallel_all_reduce`` calls,
    each followed by an ``RMSNorm``; the last three norms also take the
    running residual. Matmuls against random square weights sit between the
    stages.

    Args:
        hidden_size: Feature dimension used for the norms and the square
            weight matrices.
        eps: Epsilon forwarded to every ``RMSNorm``.
    """

    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Kept as plain Python lists (not ModuleList/ParameterList).
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        """Run the four all-reduce/norm stages and return the last norm output."""
        z = torch.relu(x)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = torch.mm(y, self.w[0])
        x2 = tensor_model_parallel_all_reduce(z2)

        y2, resid = self.norm[1](x2, resid)

        z3 = torch.mm(y2, self.w[1])
        x3 = tensor_model_parallel_all_reduce(z3)

        y3, resid = self.norm[2](x3, resid)

        z4 = torch.mm(y3, self.w[2])
        x4 = tensor_model_parallel_all_reduce(z4)

        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Ops the test expects in the graph before the passes run."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Ops the test expects in the graph after the passes run."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model(self):
        """Ops expected to remain in the final graph (checked for count > 0).

        Both ops are returned only when the RMSNorm custom op is enabled;
        otherwise the list is empty. This is exactly how the former
        ``[a] + [b] if cond else []`` expression evaluated, since ``+`` binds
        tighter than the conditional expression.
        NOTE(review): confirm that ``vllm_ir.rms_norm`` is not meant to be
        expected unconditionally.
        """
        if not RMSNorm.enabled():
            return []
        return [
            torch.ops.vllm_ir.rms_norm,
            torch.ops._C.fused_add_rms_norm.default,
        ]
99

100

101
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """FP8 variant of the all-reduce + RMSNorm test model.

    Same four-stage structure as the unquantized model, but the matmuls are
    replaced by ``TestFP8Layer`` instances that use ``quant_key`` for both
    activations and weights.

    Args:
        hidden_size: Feature dimension used for the norms and the square
            FP8 layer weights.
        eps: Epsilon forwarded to every ``RMSNorm``.
    """

    # Quantization key shared by activations and weights of the FP8 layers.
    quant_key = kFp8StaticTensorSym

    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        # Captured at construction time, so the model must be built inside
        # the intended set_current_vllm_config(...) context.
        self.vllm_config = get_current_vllm_config()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.fp8_linear_layers = [
            TestFP8Layer(
                weight_shape=(hidden_size, hidden_size),
                activation_quant_key=self.quant_key,
                weight_quant_key=self.quant_key,
                input_dtype=self.vllm_config.model_config.dtype,
            )
            for _ in range(3)
        ]

    def forward(self, hidden_states):
        """Run the four all-reduce/norm stages and return the last norm output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = self.fp8_linear_layers[0](y)

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = self.fp8_linear_layers[1](y2)

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        z4 = self.fp8_linear_layers[2](y3)
        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_after(self):
        """Ops the test expects in the graph after the passes run."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model_before(self):
        """Ops the test expects in the graph before the passes run."""
        return [
            torch.ops.vllm.all_reduce.default,
        ]

    def ops_in_model(self):
        """Ops expected to remain in the final graph (checked for count > 0).

        The branches are tried in priority order: norm+quant fusion enabled,
        then the RMSNorm custom op enabled, then any FP8 layer reporting that
        its quant op is enabled; otherwise nothing is expected.
        """
        if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        elif RMSNorm.enabled():
            return [
                torch.ops._C.fused_add_rms_norm.default,
            ]
        elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
            return [
                torch.ops._C.static_scaled_fp8_quant.default,
            ]
        else:
            return []
165
166


167
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model_cls, custom_ops",
    [
        (TestAllReduceRMSNormModel, "+rms_norm"),
        (TestAllReduceRMSNormModel, "-rms_norm"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    fuse_norm_quant: bool,
    dynamic: bool,
):
    """Spawn two worker processes that run the sequence-parallelism check.

    Each worker executes ``sequence_parallelism_pass_on_test_model`` with the
    parametrized arguments; the process index is prepended by
    ``torch.multiprocessing.spawn`` as the worker's first argument.
    """
    # Matches the num_gpus requested by the multi_gpu_test decorator.
    num_processes = 2

    # need to use torch.mp.spawn otherwise will have problems with
    # torch.distributed and cuda
    torch.multiprocessing.spawn(
        sequence_parallelism_pass_on_test_model,
        args=(
            num_processes,
            test_model_cls,
            custom_ops,
            batch_size,
            seq_len,
            hidden_size,
            dtype,
            fuse_norm_quant,
            dynamic,
        ),
        nprocs=num_processes,
    )


220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def test_sequence_parallelism_pass_requires_full_graph_compilation():
    """Check that the applicability query asserts when splitting ops are set.

    With inductor graph partitioning disabled and a non-empty
    ``splitting_ops`` list, ``is_applicable_for_range`` is expected to raise
    an ``AssertionError`` mentioning full-graph compilation.
    """
    config = VllmConfig()
    compilation_config = config.compilation_config
    compilation_config.use_inductor_graph_partition = False
    compilation_config.splitting_ops = ["vllm::unified_attention_with_output"]

    # Allocate the pass without running __init__ and set only the two
    # attributes this test assigns by hand.
    sp_pass = object.__new__(SequenceParallelismPass)
    sp_pass.compilation_config = compilation_config
    sp_pass.min_token_num = 1

    expected_message = "SequenceParallelismPass requires full-graph compilation"
    with pytest.raises(AssertionError, match=expected_message):
        sp_pass.is_applicable_for_range(Range(start=8, end=8))


238
def sequence_parallelism_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    fuse_norm_quant: bool,
    dynamic: bool,
):
    """Per-process body: compile a test model and check the pass results.

    Sets up the device and distributed environment for this rank, builds a
    vLLM config with sequence parallelism enabled, compiles the model through
    ``TestBackend`` with the selected passes, and asserts on the op counts
    before and after the passes.

    Args:
        local_rank: Index of this worker process; also used as the device
            index and as RANK/LOCAL_RANK.
        world_size: Number of worker processes; used as the tensor-parallel
            size.
        test_model_cls: Model class to instantiate with ``hidden_size``.
        custom_ops: Comma-separated custom-op toggles (e.g. ``"+rms_norm"``);
            an empty string yields an empty list.
        batch_size: Multiplied by ``seq_len`` to get the input row count.
        seq_len: See ``batch_size``.
        hidden_size: Input feature dimension and model hidden size.
        dtype: Default dtype, model-config dtype, and input dtype.
        fuse_norm_quant: Enables the pass-config flag and appends
            ``RMSNormQuantFusionPass`` to the pass list.
        dynamic: If true, dimension 0 of the input is marked dynamic.
    """
    set_random_seed(0)

    device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
    torch.accelerator.set_device_index(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment()

    # configure vllm config for SequenceParallelismPass
    custom_ops_list = custom_ops.split(",") if custom_ops else []
    compilation_config = CompilationConfig(
        splitting_ops=[],  # avoid automatic rms_norm enablement
        cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
        custom_ops=custom_ops_list,
        pass_config=PassConfig(
            enable_sp=True,
            fuse_norm_quant=fuse_norm_quant,
            eliminate_noops=True,  # NoOp needed for fusion
        ),
    )
    device_config = DeviceConfig(device=torch.device(DEVICE_TYPE))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        device_config=device_config,
        compilation_config=compilation_config,
    )

    with set_current_vllm_config(vllm_config):
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        noop_pass = NoOpEliminationPass(vllm_config)
        sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)
        # The pass must see the same compilation settings as the config it
        # was built from.
        assert (
            sequence_parallelism_pass.compilation_config.splitting_ops
            == vllm_config.compilation_config.splitting_ops
        )
        assert (
            sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
            == vllm_config.compilation_config.use_inductor_graph_partition
        )
        passes_for_backend: list[VllmInductorPass] = [
            noop_pass,
            sequence_parallelism_pass,
        ]

        if fuse_norm_quant:
            fusion_pass = RMSNormQuantFusionPass(vllm_config)
            passes_for_backend.append(fusion_pass)

        # Cleanup always runs last.
        passes_for_backend.append(cleanup_pass)

        backend = TestBackend(*passes_for_backend)

        model = test_model_cls(hidden_size)

        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

        if dynamic:
            torch._dynamo.mark_dynamic(hidden_states, 0)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # One match per all-reduce in the test models' forward pass.
        assert sequence_parallelism_pass.matched_count == 4

        # In pre-nodes, all reduce should be there,
        # reduce scatter and all gather should not
        for op in model.ops_in_model_before():
            assert backend.op_count(op, before=True) == 4

        # In post-nodes, reduce scatter and all gather should be there,
        # all reduce should not
        for op in model.ops_in_model_after():
            assert backend.op_count(op, before=False) == 4

        for op in model.ops_in_model():
            assert backend.op_count(op, before=False) > 0