# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.fx_utils import find_auto_fn
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
from vllm.config import (
    CompilationConfig,
    CUDAGraphMode,
    DeviceConfig,
    ModelConfig,
    PassConfig,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed

# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Sample text prompts.
# NOTE(review): nothing in the visible part of this module references
# `prompts` -- confirm whether it is still needed.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Toy model that alternates tensor-parallel all-reduces with RMSNorm.

    ``forward`` issues four ``tensor_model_parallel_all_reduce`` calls, each
    immediately followed by an ``RMSNorm`` layer, with plain ``torch.mm``
    projections in between. The ``ops_in_model*`` helpers list the ops the
    test checks for in the graph before and after the compilation passes.
    """

    def __init__(self, hidden_size=16, eps=1e-6):
        """Create four RMSNorm layers and three random square weights.

        Args:
            hidden_size: Feature dimension of the norms and of the
                ``(hidden_size, hidden_size)`` weight matrices.
            eps: Epsilon passed to every ``RMSNorm``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Layers and weights are held in plain Python lists (not
        # nn.ModuleList / nn.ParameterList), as in the original test.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        """Run four all-reduce + RMSNorm stages and return the last output.

        The first norm is called without a residual; the remaining three are
        called with ``(x, resid)`` and return the updated residual as well.
        """
        z = torch.relu(x)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = torch.mm(y, self.w[0])
        x2 = tensor_model_parallel_all_reduce(z2)

        y2, resid = self.norm[1](x2, resid)

        z3 = torch.mm(y2, self.w[1])
        x3 = tensor_model_parallel_all_reduce(z3)

        y3, resid = self.norm[2](x3, resid)

        z4 = torch.mm(y3, self.w[2])
        x4 = tensor_model_parallel_all_reduce(z4)

        y4, resid = self.norm[3](x4, resid)
        return y4

    def ops_in_model_before(self):
        """Return the ops counted in the graph before the passes run."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops counted in the graph after the passes run."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model(self):
        """Return the RMSNorm custom ops to look up in the post-pass graph.

        The list is empty when ``RMSNorm.enabled()`` is false.
        """
        if RMSNorm.enabled():
            return [
                torch.ops._C.rms_norm.default,
                torch.ops._C.fused_add_rms_norm.default,
            ]
        else:
            return []


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
    """Toy model alternating all-reduces, RMSNorm and FP8 linear layers.

    Same structure as the non-quantized test model in this file, except that
    the projections between stages are ``TestFP8Layer`` instances that use
    ``quant_key`` for both activations and weights.
    """

    # Quantization key used for both activations and weights of every layer.
    quant_key = kFp8StaticTensorSym

    def __init__(self, hidden_size=16, eps=1e-6):
        """Create four RMSNorm layers and three FP8 linear layers.

        The vLLM config that is current at construction time is captured so
        that ``ops_in_model`` can later read ``pass_config.fuse_norm_quant``.

        Args:
            hidden_size: Feature dimension of the norms and of the
                ``(hidden_size, hidden_size)`` FP8 layer weights.
            eps: Epsilon passed to every ``RMSNorm``.
        """
        super().__init__()
        self.vllm_config = get_current_vllm_config()
        self.hidden_size = hidden_size
        self.eps = eps
        # Layers are held in plain Python lists (not nn.ModuleList), as in
        # the original test.
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
        self.fp8_linear_layers = [
            TestFP8Layer(
                weight_shape=(hidden_size, hidden_size),
                activation_quant_key=self.quant_key,
                weight_quant_key=self.quant_key,
            )
            for _ in range(3)
        ]

    def forward(self, hidden_states):
        """Run four all-reduce + RMSNorm stages and return the last output."""
        # avoid having graph input be an arg to a pattern directly
        z = torch.relu(hidden_states)
        x = resid = tensor_model_parallel_all_reduce(z)
        y = self.norm[0](x)

        z2 = self.fp8_linear_layers[0](y)

        x2 = tensor_model_parallel_all_reduce(z2)
        y2, resid = self.norm[1](x2, resid)

        z3 = self.fp8_linear_layers[1](y2)

        x3 = tensor_model_parallel_all_reduce(z3)
        y3, resid = self.norm[2](x3, resid)  # use resid here

        z4 = self.fp8_linear_layers[2](y3)

        x4 = tensor_model_parallel_all_reduce(z4)
        y4, resid = self.norm[3](x4, resid)  # use resid here
        return y4

    def ops_in_model_before(self):
        """Return the ops counted in the graph before the passes run."""
        return [
            torch.ops.vllm.all_reduce.default,
        ]

    def ops_in_model_after(self):
        """Return the ops counted in the graph after the passes run."""
        return [
            torch.ops.vllm.all_gather.default,
            torch.ops.vllm.reduce_scatter.default,
        ]

    def ops_in_model(self):
        """Return the op to look up in the post-pass graph.

        Checked in priority order: the fused norm+quant op when
        ``fuse_norm_quant`` is set on the captured config, else the fused
        add-RMSNorm op when ``RMSNorm.enabled()``, else the static FP8 quant
        op when any FP8 layer reports ``is_quant_fp8_enabled()``, else
        nothing.
        """
        if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
        elif RMSNorm.enabled():
            return [
                torch.ops._C.fused_add_rms_norm.default,
            ]
        elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
            return [
                torch.ops._C.static_scaled_fp8_quant.default,
            ]
        else:
            return []


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "test_model_cls, custom_ops",
    [
        (TestAllReduceRMSNormModel, "+rms_norm"),
        (TestAllReduceRMSNormModel, "-rms_norm"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
        (TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
    ],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_norm_quant", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass(
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    fuse_norm_quant: bool,
    dynamic: bool,
):
    """Spawn two worker processes that check the sequence-parallelism pass.

    Every parametrized argument is forwarded unchanged to
    ``sequence_parallelism_pass_on_test_model``, which runs once per spawned
    process and performs the actual assertions.
    """
    num_processes = 2

    # need to use torch.mp.spawn otherwise will have problems with
    # torch.distributed and cuda
    torch.multiprocessing.spawn(
        sequence_parallelism_pass_on_test_model,
        # spawn() prepends the process index, which the worker receives as
        # its first positional argument (local_rank).
        args=(
            num_processes,
            test_model_cls,
            custom_ops,
            batch_size,
            seq_len,
            hidden_size,
            dtype,
            fuse_norm_quant,
            dynamic,
        ),
        nprocs=num_processes,
    )


def sequence_parallelism_pass_on_test_model(
    local_rank: int,
    world_size: int,
    test_model_cls: type[torch.nn.Module],
    custom_ops: str,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    fuse_norm_quant: bool,
    dynamic: bool,
):
    """Per-process body of the sequence-parallelism pass test.

    Sets up the CUDA device and the distributed/model-parallel state for this
    rank, compiles ``test_model_cls`` through a ``TestBackend`` that runs the
    no-op elimination, sequence-parallelism, optional RMSNorm+quant fusion and
    post-cleanup passes, and then asserts on the ops found before and after
    the passes.

    Args:
        local_rank: Index of this process; used as the CUDA device index and
            as ``RANK`` / ``LOCAL_RANK``.
        world_size: Number of processes; used as ``WORLD_SIZE`` and as the
            tensor-model-parallel size.
        test_model_cls: Model class to instantiate with ``hidden_size``.
        custom_ops: Comma-separated custom-op toggles (e.g. ``"+rms_norm"``);
            an empty string yields an empty list.
        batch_size: Multiplied by ``seq_len`` to get the number of input rows.
        seq_len: See ``batch_size``.
        hidden_size: Feature dimension of the input and of the model.
        dtype: Default dtype for this process and dtype of the input.
        fuse_norm_quant: Enables ``fuse_norm_quant`` in the pass config and
            appends ``RMSNormQuantFusionPass`` to the pass list.
        dynamic: If true, dimension 0 of the input is marked dynamic.
    """
    set_random_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables(
        {
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for SequenceParallelismPass
    custom_ops_list = custom_ops.split(",") if custom_ops else []
    compilation_config = CompilationConfig(
        splitting_ops=[],  # avoid automatic rms_norm enablement
        cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
        custom_ops=custom_ops_list,
        pass_config=PassConfig(
            enable_sp=True,
            fuse_norm_quant=fuse_norm_quant,
            eliminate_noops=True,  # NoOp needed for fusion
        ),
    )
    device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
    model_config = ModelConfig(
        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        device_config=device_config,
        compilation_config=compilation_config,
    )

    with set_current_vllm_config(vllm_config):
        noop_pass = NoOpEliminationPass(vllm_config)
        sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        # The pass must see the same compilation settings as the vLLM config.
        assert (
            sequence_parallelism_pass.compilation_config.splitting_ops
            == vllm_config.compilation_config.splitting_ops
        )
        assert (
            sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
            == vllm_config.compilation_config.use_inductor_graph_partition
        )

        # Pass order: noop elimination, sequence parallelism, optional
        # norm+quant fusion, then post-cleanup last.
        passes_for_backend: list[VllmInductorPass] = [
            noop_pass,
            sequence_parallelism_pass,
        ]

        if fuse_norm_quant:
            fusion_pass = RMSNormQuantFusionPass(vllm_config)
            passes_for_backend.append(fusion_pass)

        passes_for_backend.append(cleanup_pass)

        backend = TestBackend(*passes_for_backend)

        model = test_model_cls(hidden_size)

        hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)

        if dynamic:
            torch._dynamo.mark_dynamic(hidden_states, 0)

        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)

        # Both test models in this file issue four all-reduces in forward().
        assert sequence_parallelism_pass.matched_count == 4

        # In pre-nodes, all reduce should be there,
        # reduce scatter and all gather should not
        for op in model.ops_in_model_before():
            assert backend.op_count(op, before=True) == 4

        # In post-nodes, reduce scatter and all gather should be there,
        # all reduce should not
        for op in model.ops_in_model_after():
            assert backend.op_count(op, before=False) == 4

        # NOTE(review): the return value is discarded; presumably
        # find_auto_fn raises when the op is absent -- confirm in fx_utils.
        for op in model.ops_in_model():
            find_auto_fn(backend.graph_post_pass.nodes, op)