test_fusion_all_reduce.py 9.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
10
11
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
12
13
14
15
16
17
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                         ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
18
19
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    GroupShape, QuantFP8)
20
21
22
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

23
from ..utils import has_module_attribute, multi_gpu_test
24
25
26
27
28
from .backend import TestBackend


class TestAllReduceRMSNormModel(torch.nn.Module):
    """Toy module: tensor-parallel all-reduce followed by RMSNorm.

    This variant calls the norm without a residual input. It also exposes
    the op lists that the test driver in this file hands to the backend's
    ``check_before_ops`` / ``check_after_ops``.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Build the norm layer.

        ``token_num`` is not used by this model; it is accepted because the
        driver constructs every test model as ``cls(hidden_size, token_num)``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
        """All-reduce ``hidden_states`` and normalize the result.

        ``residual`` is ignored here; it is part of the signature because
        the driver calls every model as ``model(hidden_states, residual)``.
        """
        # Flatten to (-1, hidden_size) before the collective.
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        norm = self.norm(all_reduce)
        return norm

    def ops_in_model_before(self):
        """Return the ops the driver passes to ``check_before_ops``."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops the driver passes to ``check_after_ops``."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
    """Toy module: tensor-parallel all-reduce followed by residual RMSNorm.

    The norm is called with a residual and returns a pair; only the first
    element (the normalized output) is returned from ``forward``. The class
    also exposes the op lists that the test driver in this file hands to the
    backend's ``check_before_ops`` / ``check_after_ops``.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Build the norm layer.

        ``token_num`` is not used by this model; it is accepted because the
        driver constructs every test model as ``cls(hidden_size, token_num)``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)

    def forward(self, hidden_states, residual):
        """All-reduce ``hidden_states``, then normalize with ``residual``."""
        # Flatten to (-1, hidden_size) before the collective.
        view = hidden_states.reshape(-1, self.hidden_size)
        all_reduce = tensor_model_parallel_all_reduce(view)
        # The second element of the pair is intentionally discarded.
        norm, _ = self.norm(all_reduce, residual)
        return norm

    def ops_in_model_before(self):
        """Return the ops the driver passes to ``check_before_ops``."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops the driver passes to ``check_after_ops``."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
    """Toy module: all-reduce, residual RMSNorm, then static FP8 quant.

    ``forward`` writes the quantized result into the preallocated
    ``self.output`` buffer via ``torch.ops._C.static_scaled_fp8_quant`` and
    returns that buffer together with the residual produced by the norm.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the norm layer, the quant scale and the output buffer."""
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)
        # Constructed here but not called in forward(), which invokes the
        # custom quant op directly.
        self.quant_fp8 = QuantFP8(
            static=True,
            group_shape=GroupShape.PER_TENSOR,
        )
        self.scale = torch.rand(1, dtype=torch.float32)
        out_shape = (token_num, hidden_size)
        self.output = torch.empty(out_shape, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        """Return ``(self.output, residual_output)`` after the quant op."""
        flat = hidden_states.reshape(-1, self.hidden_size)
        reduced = tensor_model_parallel_all_reduce(flat)
        normed, residual_out = self.norm(reduced, residual)
        # The op fills self.output in place; its return value is unused.
        torch.ops._C.static_scaled_fp8_quant(
            self.output,
            normed.contiguous(),
            self.scale,
        )
        return self.output, residual_out

    def ops_in_model_before(self):
        """Return the ops the driver passes to ``check_before_ops``."""
        all_reduce_op = torch.ops.vllm.all_reduce.default
        quant_op = torch.ops._C.static_scaled_fp8_quant.default
        return [all_reduce_op, quant_op]

    def ops_in_model_after(self):
        """Return the ops the driver passes to ``check_after_ops``."""
        fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
        return [fused_op]


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
    """Toy module: all-reduce, residual RMSNorm, then FP4 quant.

    ``forward`` calls ``torch.ops._C.scaled_fp4_quant`` with the
    preallocated ``self.output`` and ``self.output_scale`` buffers and
    returns both buffers along with the residual produced by the norm.
    """

    def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
        """Create the norm layer, the input scale and both output buffers."""
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.norm = RMSNorm(hidden_size, eps)
        self.scale = torch.rand(1, dtype=torch.float32)
        self.output = torch.empty((token_num, hidden_size),
                                  dtype=torch.float32)

        def ceil_to_multiple(value, multiple):
            # Smallest multiple of `multiple` that is >= `value`.
            return (value + multiple - 1) // multiple * multiple

        # Scale-buffer shape: rows padded to a multiple of 128; columns are
        # hidden_size // 16 padded to a multiple of 4, then divided by 4.
        padded_rows = ceil_to_multiple(token_num, 128)
        scale_cols = hidden_size // 16
        padded_cols = ceil_to_multiple(scale_cols, 4)
        scale_shape = (padded_rows, padded_cols // 4)
        self.output_scale = torch.empty(scale_shape, dtype=torch.int32)

    def forward(self, hidden_states, residual):
        """Return ``(self.output, residual_output, self.output_scale)``."""
        flat = hidden_states.reshape(-1, self.hidden_size)
        reduced = tensor_model_parallel_all_reduce(flat)
        normed, residual_out = self.norm(reduced, residual)
        # Collapse leading dims so the quant op receives a 2-D input.
        normed = normed.reshape(-1, normed.shape[-1])
        torch.ops._C.scaled_fp4_quant(
            self.output,
            normed,
            self.output_scale,
            self.scale,
        )
        return self.output, residual_out, self.output_scale

    def ops_in_model_before(self):
        """Return the ops the driver passes to ``check_before_ops``."""
        all_reduce_op = torch.ops.vllm.all_reduce.default
        quant_op = torch.ops._C.scaled_fp4_quant.default
        return [all_reduce_op, quant_op]

    def ops_in_model_after(self):
        """Return the ops the driver passes to ``check_after_ops``."""
        fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
        return [fused_op]


138
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [
    TestAllReduceRMSNormModel,
    TestAllReduceFusedAddRMSNormModel,
    TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
    TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
@pytest.mark.skipif(
    not find_spec("flashinfer")
    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
    reason="flashinfer is not found or flashinfer "
    "is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
                                        batch_size: int, seq_len: int,
                                        hidden_size: int, dtype: torch.dtype):
    """Spawn two worker processes that run the all-reduce fusion check.

    Each worker executes ``all_reduce_fusion_pass_on_test_model`` with the
    parametrized model class, sizes and dtype. The test is skipped when the
    target device is not CUDA, when flashinfer (or its
    ``trtllm_allreduce_fusion`` attribute) is unavailable, and for the FP4
    model when the device lacks capability 100.

    Args:
        test_model: The model *class* to instantiate inside each worker.
        batch_size: Multiplied by ``seq_len`` in the worker to get the
            token count.
        seq_len: See ``batch_size``.
        hidden_size: Last dimension of the generated inputs.
        dtype: Default dtype set inside each worker.
    """
    num_processes = 2

    if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
            and not current_platform.has_device_capability(100)):
        pytest.skip("Skip as nvfp4 is only supported on "
                    "devices with compute capability 10.0 (Blackwell)")

    # spawn() prepends the process index to `args`, so the worker receives
    # (local_rank, world_size, model class, batch_size, seq_len,
    # hidden_size, dtype).
    torch.multiprocessing.spawn(all_reduce_fusion_pass_on_test_model,
                                args=(num_processes, test_model, batch_size,
                                      seq_len, hidden_size, dtype),
                                nprocs=num_processes)


def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
                                         test_model_cls: torch.nn.Module,
                                         batch_size: int, seq_len: int,
                                         hidden_size: int, dtype: torch.dtype):
    """Per-process worker: compile a test model and check op replacement.

    Sets up the device and the distributed environment for this rank, builds
    a ``VllmConfig`` with the flashinfer all-reduce fusion pass enabled,
    compiles ``test_model_cls`` through ``TestBackend`` and then asks the
    backend to check the model's "before" and "after" op lists.

    Args:
        local_rank: Index of this process; also used as the CUDA device
            index and as ``RANK`` / ``LOCAL_RANK``.
        world_size: Number of processes; used as ``WORLD_SIZE`` and as the
            tensor-parallel size.
        test_model_cls: Model class, constructed as
            ``test_model_cls(hidden_size, token_num)``.
        batch_size: Multiplied by ``seq_len`` to get ``token_num``.
        seq_len: See ``batch_size``.
        hidden_size: Last dimension of the random inputs.
        dtype: Installed as the torch default dtype and passed to
            ``ModelConfig``.
    """
    current_platform.seed_everything(0)

    # Bind this process to its own GPU and make it the default for new
    # tensors, so the inputs below land on it without explicit device args.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Rendezvous settings, exported before init_distributed_environment().
    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        custom_ops=["+rms_norm", "+quant_fp8"]))
    vllm_config.compilation_config.pass_config = PassConfig(
        enable_fi_allreduce_fusion=True, enable_noop=True)
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model_name,
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
    noop_pass = NoOpEliminationPass(vllm_config)
    func_pass = FixFunctionalizationPass(vllm_config)

    # The passes are handed to the backend in this order: fusion, no-op
    # elimination, fix-functionalization.
    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

    token_num = batch_size * seq_len
    model = test_model_cls(hidden_size, token_num)

    hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
    residual = torch.randn((token_num, hidden_size), requires_grad=False)

    # Calling the compiled model once runs it through the test backend.
    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states, residual)

    # NOTE(review): exact semantics of fully_replaced=False live in
    # TestBackend; presumably the "before" ops need not all disappear.
    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
    backend.check_after_ops(model.ops_in_model_after())
    # Drop the local reference to the pass explicitly; presumably this lets
    # it release its resources before the process exits. Confirm against
    # AllReduceFusionPass.
    del all_reduce_fusion_pass