# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import (
    CompilationConfig,
    CUDAGraphMode,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed

# Device type string (e.g. "cuda") used to build per-rank torch.device objects.
DEVICE_TYPE = current_platform.device_type

# Module-wide skip: every test here needs a CUDA platform.
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")

# NOTE(review): not referenced by the tests in this file; candidate for
# removal if nothing imports it from here.
FP8_DTYPE = current_platform.fp8_dtype()

# NOTE(review): not referenced by the tests in this file; candidate for
# removal if nothing imports it from here.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Toy tensor-parallel model with four all-reduce -> RMSNorm sites.

    The first RMSNorm runs without a residual. The remaining three are fed by
    a matmul + all-reduce and take/return the running residual. This gives the
    sequence-parallelism pass four all-reduce sites to rewrite.
    """

    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # NOTE(review): plain Python lists, so these norms/weights are not
        # registered as submodules/parameters of this Module.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        # relu keeps the graph input from feeding the first all-reduce directly.
        hidden = torch.relu(x)
        hidden = residual = tensor_model_parallel_all_reduce(hidden)
        out = self.norm[0](hidden)
        # torch.compile traces through this Python loop, so the captured graph
        # holds three mm -> all_reduce -> norm(x, residual) stages in order,
        # exactly as an unrolled body would.
        for weight, norm in zip(self.w, self.norm[1:]):
            reduced = tensor_model_parallel_all_reduce(torch.mm(out, weight))
            out, residual = norm(reduced, residual)
        return out

    def ops_in_model_before(self):
        """Ops expected in the graph before the SP pass rewrites it."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Ops the SP pass is expected to introduce in place of all_reduce."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model(self):
        """Norm ops expected in the final graph; empty unless RMSNorm is enabled."""
        # Grouping made explicit: the original `A + B if cond else []` already
        # parsed as `(A + B) if cond else []`.
        if not RMSNorm.enabled():
            return []
        return [
            torch.ops.vllm_ir.rms_norm,
            torch.ops._C.fused_add_rms_norm.default,
        ]


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """Toy tensor-parallel model with four all-reduce -> RMSNorm sites and FP8 linears.

    Same topology as the unquantized variant. Each matmul is a ``TestFP8Layer``
    configured with ``kFp8StaticTensorSym`` (presumably static per-tensor
    symmetric FP8) for both activations and weights, so RMSNorm + quant fusion
    can be exercised too.
    """

    quant_key = kFp8StaticTensorSym

    def __init__(self, hidden_size=16, eps=1e-6):
        super().__init__()
        # NOTE(review): reads the *current* vllm config, so construction must
        # happen inside an active set_current_vllm_config context.
        self.vllm_config = get_current_vllm_config()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.fp8_linear_layers = [
            TestFP8Layer(
                weight_shape=(hidden_size, hidden_size),
                activation_quant_key=self.quant_key,
                weight_quant_key=self.quant_key,
                input_dtype=self.vllm_config.model_config.dtype,
            )
            for _ in range(3)
        ]

    def forward(self, hidden_states):
        # avoid having graph input be an arg to a pattern directly
        activated = torch.relu(hidden_states)
        reduced = residual = tensor_model_parallel_all_reduce(activated)
        normed = self.norm[0](reduced)
        # torch.compile traces through this Python loop, producing three
        # fp8_linear -> all_reduce -> norm(x, residual) stages in order.
        for linear, norm in zip(self.fp8_linear_layers, self.norm[1:]):
            reduced = tensor_model_parallel_all_reduce(linear(normed))
            normed, residual = norm(reduced, residual)
        return normed

    def ops_in_model_after(self):
        """Ops the SP pass is expected to introduce in place of all_reduce."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model_before(self):
        """Ops expected in the graph before the SP pass rewrites it."""
        return [
            torch.ops.vllm.all_reduce.default,
        ]

    def ops_in_model(self):
        """Kernel expected in the final graph, chosen by config (first match wins).

        1. ``fuse_norm_quant`` on -> fused add+RMSNorm+static FP8 quant kernel.
        2. RMSNorm custom op enabled -> fused add+RMSNorm kernel.
        3. any linear layer has quant_fp8 enabled -> static FP8 quant kernel.
        4. otherwise nothing specific is expected.
        """
        pass_config = self.vllm_config.compilation_config.pass_config
        if pass_config.fuse_norm_quant:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        if RMSNorm.enabled():
            return [torch.ops._C.fused_add_rms_norm.default]
        if any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
            return [torch.ops._C.static_scaled_fp8_quant.default]
        return []


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model_cls, custom_ops",
    [
        (TestAllReduceRMSNormModel, "+rms_norm"),
        (TestAllReduceRMSNormModel, "-rms_norm"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    fuse_norm_quant: bool,
    dynamic: bool,
):
    """Run the sequence-parallelism pass check on two ranks.

    One worker process per rank is started with torch.multiprocessing.spawn
    (needed, per the original authors, to avoid problems between
    torch.distributed and CUDA). Each worker runs
    ``sequence_parallelism_pass_on_test_model``.
    """
    num_processes = 2
    # spawn calls fn(rank, *args): the worker receives
    # (local_rank, world_size, test_model_cls, custom_ops, batch_size,
    #  seq_len, hidden_size, dtype, fuse_norm_quant, dynamic).
    worker_args = (
        num_processes,
        test_model_cls,
        custom_ops,
        batch_size,
        seq_len,
        hidden_size,
        dtype,
        fuse_norm_quant,
        dynamic,
    )
    torch.multiprocessing.spawn(
        sequence_parallelism_pass_on_test_model,
        args=worker_args,
        nprocs=num_processes,
    )


def sequence_parallelism_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    fuse_norm_quant: bool,
    dynamic: bool,
):
    """Per-rank worker: compile the test model with the SP pass and check its graphs.

    The worker proceeds in four steps:
    1. Bind this process to device ``local_rank`` and initialize
       torch.distributed with ``world_size`` ranks.
    2. Build a VllmConfig with sequence parallelism enabled and, when
       ``fuse_norm_quant`` is set, RMSNorm+quant fusion.
    3. Compile ``test_model_cls`` through a TestBackend running
       NoOpElimination -> SequenceParallelism -> [RMSNormQuantFusion] ->
       PostCleanup.
    4. Assert on the captured pre-/post-pass graphs.
    """
    set_random_seed(0)

    device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
    torch.accelerator.set_device_index(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # NOTE(review): fixed rendezvous port; concurrent runs on the same host
    # may collide on 12345.
    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment()

    # configure vllm config for SequenceParallelismPass
    custom_ops_list = custom_ops.split(",") if custom_ops else []
    compilation_config = CompilationConfig(
        splitting_ops=[],  # avoid automatic rms_norm enablement
        cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
        custom_ops=custom_ops_list,
        pass_config=PassConfig(
            enable_sp=True,
            fuse_norm_quant=fuse_norm_quant,
            eliminate_noops=True,  # NoOp needed for fusion
        ),
    )
    device_config = DeviceConfig(device=torch.device(DEVICE_TYPE))

    # Placeholder model name, used only to construct ModelConfig for the
    # VllmConfig; this test never loads its weights.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        device_config=device_config,
        compilation_config=compilation_config,
    )

    with set_current_vllm_config(vllm_config):
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        noop_pass = NoOpEliminationPass(vllm_config)
        sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)
        assert (
            sequence_parallelism_pass.compilation_config.splitting_ops
            == vllm_config.compilation_config.splitting_ops
        )
        assert (
            sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
            == vllm_config.compilation_config.use_inductor_graph_partition
        )
        passes_for_backend: list[VllmInductorPass] = [
            noop_pass,
            sequence_parallelism_pass,
        ]
        if fuse_norm_quant:
            passes_for_backend.append(RMSNormQuantFusionPass(vllm_config))
        passes_for_backend.append(cleanup_pass)

        backend = TestBackend(*passes_for_backend)

        model = test_model_cls(hidden_size)

        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

        if dynamic:
            torch._dynamo.mark_dynamic(hidden_states, 0)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # Both test models contain exactly four all-reduce -> RMSNorm sites.
        expected_sites = 4
        assert sequence_parallelism_pass.matched_count == expected_sites

        # In pre-nodes, all reduce should be there,
        # reduce scatter and all gather should not
        for op in model.ops_in_model_before():
            assert backend.op_count(op, before=True) == expected_sites
        for op in model.ops_in_model_after():
            assert backend.op_count(op, before=True) == 0

        # In post-nodes, reduce scatter and all gather should be there,
        # all reduce should not
        for op in model.ops_in_model_after():
            assert backend.op_count(op, before=False) == expected_sites
        for op in model.ops_in_model_before():
            assert backend.op_count(op, before=False) == 0

        for op in model.ops_in_model():
            assert backend.op_count(op, before=False) > 0