rlhf.py 4.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
A simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design in which training processes and inference processes
are separate and live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
16

17
18
19
20
21
22
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
23
from rlhf_utils import stateless_init_process_group
24
25
from transformers import AutoModelForCausalLM

26
from vllm import LLM, SamplingParams
27
28
29
30
31
32
33
34
from vllm.utils import get_ip, get_open_port


class MyLLM(LLM):
    """`LLM` subclass that this script launches as a Ray actor.

    Its only customization is clearing ``CUDA_VISIBLE_DEVICES`` before the
    regular ``LLM`` constructor runs.
    """

    def __init__(self, *engine_args, **engine_kwargs):
        # A hack to make the script work: stop Ray from manipulating
        # CUDA_VISIBLE_DEVICES at the top level (this actor process).
        # Giving pop() a default makes this a no-op when the variable is unset.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*engine_args, **engine_kwargs)


"""
Start the training process. Here we use Hugging Face transformers
as an example to hold a model on GPU 0.
"""

# Module.to() moves the parameters in place and returns the module itself,
# so the load and the device move can be chained.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to("cuda:0")
"""
Start the inference process. Here we use vLLM to hold a model on GPU 1 and
GPU 2. For details on how to use Ray, please refer to the Ray
documentation https://docs.ray.io/en/latest/ .
"""
# Only GPUs 1 and 2 are exposed to Ray; GPU 0 already holds the training model.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Two bundles, each reserving one GPU and no CPU (presumably one per
# tensor-parallel worker of the engine launched below).
inference_bundles = [{"GPU": 1, "CPU": 0} for _ in range(2)]
pg_inference = placement_group(inference_bundles)
# Block until the placement group's resources are actually reserved.
ray.get(pg_inference.ready())

# Schedule onto bundle 0; capture_child_tasks makes tasks/actors spawned from
# the scheduled actor land in the same placement group.
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group_bundle_index=0,
    placement_group_capture_child_tasks=True,
    placement_group=pg_inference,
)
"""
Launch the vLLM inference engine.
Here we use `enforce_eager` to reduce the start-up time.
"""
# The engine driver actor itself reserves no CPU or GPU; it is placed via the
# placement-group scheduling strategy built above.
llm_actor_cls = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM)

# Two tensor-parallel workers driven by Ray. worker_extension_cls presumably
# provides the weight-sync RPCs invoked later via collective_rpc (confirm in
# rlhf_utils).
llm = llm_actor_cls.remote(
    model="facebook/opt-125m",
    enforce_eager=True,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    worker_extension_cls="rlhf_utils.WorkerExtension",
)

# Generate texts from the prompts with temperature 0 (greedy decoding).
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))

separator = "-" * 50
print(separator)
for request_output in outputs:
    completion = request_output.outputs[0].text
    print(f"Prompt: {request_output.prompt!r}\nGenerated text: {completion!r}")
    print(separator)

# Set up the communication between the training process and the inference
# engine.
master_address = get_ip()
master_port = get_open_port()

# NOTE(review): these values presumably mean a 3-rank group in which this
# training process is rank 0 and the two tensor-parallel inference workers
# start at rank 1 -- confirm against rlhf_utils.
weight_sync_world_size = 3
trainer_rank = 0
engine_rank_offset = 1

# The engine-side join is issued asynchronously first, then the local join
# runs, and only afterwards do we wait on the engine side, so neither side
# blocks the other from entering the rendezvous.
handle = llm.collective_rpc.remote(
    "init_weight_update_group",
    args=(master_address, master_port, engine_rank_offset, weight_sync_world_size),
)
model_update_group = stateless_init_process_group(
    master_address,
    master_port,
    trainer_rank,
    weight_sync_world_size,
    torch.device("cuda:0"),
)
ray.get(handle)

# Simulate training: modify the weights of the model by zeroing every
# parameter in place. torch.no_grad() is the supported way to mutate
# parameters outside of autograd (rather than going through `.data`), and the
# parameter names are not needed here.
with torch.no_grad():
    for p in train_model.parameters():
        p.zero_()

# Sync the weights from the training process to the inference engine, one
# parameter at a time. The engine is sent the name, dtype and shape
# (presumably so it can allocate a receive buffer); rank 0 then broadcasts the
# tensor over the weight-update group, and we wait for the engine to finish.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote(
        "update_weight", args=(name, p.dtype, p.shape)
    )
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

# Check that the weights were updated on the inference engine.
# This is an explicit check rather than `assert`: `python -O` strips assert
# statements, which would silently skip both the verification and the RPC.
# collective_rpc presumably returns one result per engine worker; an empty
# result is treated as a failure because all([]) is vacuously True.
weights_changed = ray.get(llm.collective_rpc.remote("check_weights_changed"))
if not weights_changed or not all(weights_changed):
    raise RuntimeError(
        "weights were not updated on the inference engine: "
        f"check_weights_changed returned {weights_changed!r}"
    )

# Use the updated model to generate texts. They will be nonsense because the
# weights are all zeros.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)