test_torchrun_example_moe.py 2.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# unit test for `examples/offline_inference/torchrun_example.py`
import os
import random

import torch.distributed as dist

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group

dist.init_process_group(backend="gloo")

# Create prompts
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 10
# Data-parallel layout, read from the environment (defaults: no data parallelism).
dp_size = int(os.getenv("DP_SIZE", "1"))
dp_rank = int(os.getenv("DP_RANK", "0"))

if dp_size > 1:
    # An out-of-range rank would silently receive an empty shard, which would
    # make every output-consistency check below trivially pass.
    if not 0 <= dp_rank < dp_size:
        raise ValueError(f"DP_RANK={dp_rank} is out of range for DP_SIZE={dp_size}")
    # Distribute the prompts across the data parallel ranks round-robin:
    # DP rank r keeps every prompt whose index is congruent to r modulo dp_size.
    prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Set a different `gpu_memory_utilization` on each rank, to test that all
# ranks still agree on the same KV-cache configuration.
llm = LLM(
    model="microsoft/Phi-mini-MoE-instruct",
    tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
    pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
    enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
    # Presumably each rank is started by an external launcher (torchrun)
    # rather than by vLLM itself — see the module header comment.
    distributed_executor_backend="external_launcher",
    # `random` is not seeded in this script, so each rank draws its own value.
    gpu_memory_utilization=random.uniform(0.7, 0.9),
    # Identical seed on every rank so generated text can be compared below.
    seed=0,
    max_model_len=1024,
    max_num_seqs=16,
)

outputs = llm.generate(prompts, sampling_params)

# Pick the ranks whose results must match. Without data parallelism every
# rank sees the full prompt list, so compare across the whole world; with
# data parallelism the prompts are sharded, so compare within the TP group.
if dp_size == 1:
    group = get_world_group()
else:
    group = get_tp_group()
cpu_group = group.cpu_group
group_rank = dist.get_rank(group=cpu_group)


def test_consistent_across_ranks(obj):
    """Assert that ``obj`` has the same value on every rank of ``group``.

    Rank 0 of the group broadcasts its value of ``obj`` over ``cpu_group``.
    Every other rank receives that value and compares it with its own local
    ``obj``.

    Args:
        obj: A picklable value computed independently on each rank.

    Raises:
        AssertionError: On a receiving rank whose local ``obj`` differs from
            the value broadcast by the group's first rank.
    """
    # `src` must be a global rank, so translate group rank 0 through
    # `group.ranks` instead of passing 0 directly.
    src = group.ranks[0]
    if group_rank == 0:
        dist.broadcast_object_list([obj], src=src, group=cpu_group)
    else:
        container = [None]
        dist.broadcast_object_list(container, src=src, group=cpu_group)
        assert container[0] == obj, (
            f"Rank {group_rank}: got {obj!r}, but global rank {src} "
            f"broadcast {container[0]!r}"
        )


# Every rank must settle on the same KV-cache block counts, even though each
# rank was constructed with a different `gpu_memory_utilization`.
cache_config = llm.llm_engine.vllm_config.cache_config
test_consistent_across_ranks(cache_config.num_cpu_blocks)
test_consistent_across_ranks(cache_config.num_gpu_blocks)

# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(
    llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters()
)
test_consistent_across_ranks(len(params))

# All ranks in `group` should have produced the same outputs. Compare the
# counts first: each iteration below issues collectives, so a count mismatch
# would otherwise leave some ranks blocked in a broadcast instead of failing.
test_consistent_across_ranks(len(outputs))
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    test_consistent_across_ranks(prompt)
    test_consistent_across_ranks(generated_text)
    print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")