# test_verl_engine_4_gpu.py
import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.utils import is_port_available
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.runners import (
    HFRunner,
    SRTRunner,
    check_close_model_outputs,
    get_dtype_str,
)
from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci

_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16

# Set to false to temporarily debug issues unrelated to weight update
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False

# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
# Each entry is expanded as keyword arguments into
# TestVerlEngine.assert_fragment_e2e_execution; the number of worker
# processes started for an entry is dp_size * tp_size.
ALL_MODELS = [
    dict(
        model_path="Qwen/Qwen2.5-0.5B",
        dp_size=2,
        tp_size=2,  # default to 2
    ),
    dict(
        model_path="Qwen/Qwen2.5-14B-Instruct",
        mem_fraction_static=0.7,
        dp_size=2,
        tp_size=2,
        tight_memory=True,
        decode_tolerance=1.3,
    ),  # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
    dict(
        model_path="THUDM/glm-4-9b-chat",
        mem_fraction_static=0.5,
        dp_size=2,
        tp_size=2,
        tight_memory=True,
    ),
    # Fail to run these models in test_generation_models.py, need to fix that first
    # dict(model_path="openai-community/gpt2"),
    # dict(model_path="microsoft/Phi-3-small-8k-instruct"),
]


class TestVerlEngine(CustomTestCase):
    """End-to-end comparison of VerlEngine generations against HuggingFace.

    Each case starts ``dp_size * tp_size`` worker processes running
    ``_run_subprocess`` (one per rank). Each worker reports a single success
    flag back through a one-way pipe.
    """

    @classmethod
    def setUpClass(cls):
        # Workers are started with the "spawn" method. ``force=True`` prevents a
        # RuntimeError ("context has already been set") when the start method
        # was already chosen earlier in this interpreter.
        multiprocessing.set_start_method("spawn", force=True)

    def assert_fragment_e2e_execution(
        self,
        index: int,
        model_path: str,
        mem_fraction_static: float = 0.4,
        dp_size: int = 1,
        tp_size: int = 2,
        tight_memory: bool = False,
        prefill_tolerance: float = 0.1,
        decode_tolerance: float = 0.1,
    ):
        """Run one model on ``dp_size * tp_size`` ranks; assert every rank succeeded.

        Args:
            index: Position of the model in the calling test's list (for logs).
            model_path: HF model id or local path.
            mem_fraction_static: Forwarded to each worker's ``VerlEngine``.
            dp_size: Number of data-parallel groups.
            tp_size: Tensor-parallel size per group.
            tight_memory: Forwarded to the workers (offload HF model to CPU
                before building the FSDP state dict).
            prefill_tolerance: Forwarded to ``check_close_model_outputs``.
            decode_tolerance: Forwarded to ``check_close_model_outputs``.
        """
        master_port = find_available_port(23456)

        print(f"assert_fragment_e2e_execution START {index=} {model_path=}")

        world_size = dp_size * tp_size
        processes = []
        output_reader, output_writer = mp.Pipe(duplex=False)
        try:
            for rank in range(world_size):
                p = Process(
                    target=_run_subprocess,
                    kwargs=dict(
                        rank=rank,
                        dp_size=dp_size,
                        tp_size=tp_size,
                        master_port=master_port,
                        output_writer=output_writer,
                        model_path=model_path,
                        mem_fraction_static=mem_fraction_static,
                        tight_memory=tight_memory,
                        prefill_tolerance=prefill_tolerance,
                        decode_tolerance=decode_tolerance,
                    ),
                )
                p.start()
                processes.append(p)

            # Drop the parent's copy of the write end. If every worker exits
            # without reporting, recv() then raises EOFError instead of
            # blocking forever.
            output_writer.close()

            # Every rank sends exactly one flag, so read world_size of them.
            # Reading only tp_size would leave the other data-parallel ranks
            # unchecked.
            for _ in range(world_size):
                self.assertTrue(
                    output_reader.recv(),
                    f"Subprocess has error, please see logs above. ({index=} {model_path=})",
                )
        finally:
            # Reap the workers even when an assertion above fails.
            for p in processes:
                p.join()
            output_reader.close()

    def test_ci_models(self):
        """Run one randomly chosen model from ALL_MODELS."""
        ci_models = [random.choice(ALL_MODELS)]
        for index, model_info in enumerate(ci_models):
            self.assert_fragment_e2e_execution(index=index, **model_info)

    def test_others(self):
        """Outside CI, run every model in ALL_MODELS."""
        if is_in_ci():
            return

        for index, model_info in enumerate(ALL_MODELS):
            self.assert_fragment_e2e_execution(index=index, **model_info)

    # def test_adhoc(self):
    #     self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")


def _run_subprocess(
    rank: int,
    dp_size: int,
    tp_size: int,
    master_port: int,
    output_writer,
    model_path: str,
    mem_fraction_static: float,
    tight_memory: bool,
    prefill_tolerance: float,
    decode_tolerance: float,
):
    """Worker for one rank: compare VerlEngine generations with HuggingFace.

    The worker:

    1. Joins a ``dp_size * tp_size`` process group rendezvousing at
       ``localhost:master_port``.
    2. Generates reference outputs with a HuggingFace model.
    3. If ``_ENABLE_UPDATE_WEIGHTS`` is set, loads the HF weights into a
       ``VerlEngine`` (created with dummy weights) via
       ``update_weights_from_tensor``.
    4. Checks both unbatched and batched SRT generations against the
       reference with ``check_close_model_outputs``.

    Args:
        rank: Global rank of this worker; also used as its CUDA device index.
        dp_size: Number of data-parallel groups.
        tp_size: Tensor-parallel size of each engine.
        master_port: Port used for the ``torch.distributed`` rendezvous.
        output_writer: Write end of a multiprocessing pipe. Exactly one bool
            (``True`` on success) is sent through it, then it is closed.
        model_path: HF model id or local path.
        mem_fraction_static: Forwarded to ``VerlEngine``.
        tight_memory: When weight updates are enabled, move the HF model to
            CPU and empty the CUDA cache before building the FSDP state dict.
        prefill_tolerance: Forwarded to ``check_close_model_outputs``.
        decode_tolerance: Forwarded to ``check_close_model_outputs``.

    Exceptions raised during the comparison are printed with a traceback and
    reported as ``False``. The engine, if it was created, is shut down after
    the result has been reported.
    """
    # Initialised up front so cleanup does not have to probe locals().
    engine = None
    try:
        print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
        torch.cuda.set_device(rank)

        # First global rank (and GPU id) of this rank's tensor-parallel group.
        base_gpu_id = rank // tp_size * tp_size

        mesh_kwargs = dict(
            mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
        )
        inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
        inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
        print(
            f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
        )

        # hf model is used for comparison
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
        ).cuda()
        hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)

        hf_outputs = HFRunner.forward_generation_raw(
            base_model=hf_model,
            prompts=_PROMPTS,
            max_new_tokens=_MAX_NEW_TOKENS,
            tokenizer=hf_tokenizer,
            lora_paths=None,
            torch_dtype=_TORCH_DTYPE,
            output_str_only=False,
        )
        print(
            f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
            flush=True,
        )

        if _ENABLE_UPDATE_WEIGHTS:
            if tight_memory:
                hf_model.cpu()
                torch.cuda.empty_cache()

            # test update weights
            print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
            fsdp_state_dict = _get_fsdp_state_dict(
                hf_model=hf_model, world_size=dp_size * tp_size
            )

        engine = VerlEngine(
            model_path=model_path,
            # With weight updates enabled the engine starts from dummy weights;
            # the real weights only arrive via update_weights_from_tensor below.
            load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
            mem_fraction_static=mem_fraction_static,
            random_seed=42,
            base_gpu_id=base_gpu_id,
            trust_remote_code=True,
            dtype=get_dtype_str(_TORCH_DTYPE),
            device_mesh_cpu=inference_device_mesh_cpu["tp"],
        )
        print(f"subprocess[{rank=}] {engine=}", flush=True)

        if _ENABLE_UPDATE_WEIGHTS:
            print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
            engine.update_weights_from_tensor(list(fsdp_state_dict.items()))

        for enable_batch in [False, True]:
            if enable_batch:
                fn = SRTRunner.batch_forward_generation_raw
            else:
                fn = SRTRunner.forward_generation_raw

            srt_outputs = fn(
                prompts=_PROMPTS,
                max_new_tokens=_MAX_NEW_TOKENS,
                lora_paths=None,
                engine=engine,
            )
            print(
                f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
                flush=True,
            )

            check_close_model_outputs(
                hf_outputs=hf_outputs,
                srt_outputs=srt_outputs,
                prefill_tolerance=prefill_tolerance,
                decode_tolerance=decode_tolerance,
                rouge_l_tolerance=1,
                # Logprobs are only compared for the unbatched path.
                check_logprobs=not enable_batch,
                debug_text=f"{enable_batch=} {rank=}",
            )

        execution_ok = True

    except Exception as e:
        # Process boundary: report failure to the parent instead of raising.
        print(f"subprocess[{rank=}] has error: {e}", flush=True)
        traceback.print_exc()
        execution_ok = False

    output_writer.send(execution_ok)
    output_writer.close()

    if engine is not None:
        engine.shutdown()
    print(f"subprocess[{rank=}] end", flush=True)


# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, world_size: int):
    """Wrap ``hf_model`` in FSDP and return its sharded state dict.

    Builds a 1-D CUDA device mesh of ``world_size`` ranks (dimension name
    ``"fsdp"``) and wraps the model with ``ShardingStrategy.FULL_SHARD``. The
    wrapper's state-dict type is set to ``SHARDED_STATE_DICT`` before
    ``state_dict()`` is called.

    Expects ``torch.distributed`` to be initialised already; the caller
    ``_run_subprocess`` does this before calling.

    Args:
        hf_model: The ``torch.nn.Module`` to wrap. FSDP places it on the
            current CUDA device (``device_id=torch.cuda.current_device()``).
        world_size: Number of ranks in the FSDP mesh.

    Returns:
        The result of ``fsdp_model.state_dict()`` under
        ``StateDictType.SHARDED_STATE_DICT``.
    """
    device_mesh = init_device_mesh(
        "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
    )

    # NOTE(review): param_dtype is bfloat16 while the test loads the model in
    # float16. With use_orig_params=True the state dict is expected to expose
    # the original parameters; confirm the dtypes the engine receives.
    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )
    fsdp_model = FSDP(
        hf_model,
        use_orig_params=True,
        auto_wrap_policy=None,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        cpu_offload=CPUOffload(offload_params=False),
        sync_module_states=False,
        device_mesh=device_mesh,
    )
    print(f"{fsdp_model=}")

    FSDP.set_state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(),
    )

    return fsdp_model.state_dict()


if __name__ == "__main__":
    # Entry point when this file is executed directly rather than via a test
    # runner; discovers and runs the TestCase classes defined in this module.
    unittest.main(module="__main__")