The Design of ``verl.single_controller``
==============================================

Last updated: 05/21/2025.

**Author:**\  `Wang Zhang <https://github.com/zw0610>`__

Preface
-------

We prepared this document for developers of ``verl``, particularly those
interested in understanding or contributing to the
``verl.single_controller`` module. It is not intended for end users, but
for contributors seeking to understand the architectural rationale and
internal mechanics.

--------------

Origin
------

The ``single_controller`` module originated from a request I received —
to adapt a toy single-process RLHF script into a distributed system with
minimal changes, while maintaining ease of debugging.

Common practice, such as using PyTorch's Distributed Data Parallel
(DDP), typically involves wrapping ``nn.Module`` and launching multiple
processes that execute the same function under different ranks. However,
this approach presents two main limitations in the context of
distributed RLHF:

-  difficulty representing the multiple DAGs required by PPO;
-  difficulty inspecting intermediate tensors during training.

To maintain debuggability, we opted for a different approach — breaking
the training loop into well-defined stages like ``generate_sequences``,
``compute_advantages``, and so on.
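
Sketched below is what such a staged driver loop can look like; the
worker-group and method names here are illustrative rather than the
exact ``verl`` trainer API.

.. code:: python

   # Hypothetical staged PPO driver loop; names are illustrative only.
   for batch in dataloader:
       batch = rollout_wg.generate_sequences(batch)   # rollout stage
       batch = critic_wg.compute_values(batch)        # value estimation
       batch = reward_wg.compute_rewards(batch)       # reward scoring
       batch = compute_advantages(batch)              # advantage estimation on the driver
       actor_wg.update_actor(batch)                   # policy update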

We selected `Ray <https://www.ray.io/>`__ as the initial backend for
``verl`` due to its ability to expose Python class methods as RPC
endpoints. However, Ray's default model maps **one method call to one
RPC** on a single actor, while training an LLM typically requires the
same logical call to be coordinated across multiple worker processes.
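
To make the limitation concrete, the toy snippet below (not taken from
the ``verl`` code base) shows what the driver has to do with plain Ray
actors: splitting the input, issuing one RPC per actor, and gathering
the results are all manual.

.. code:: python

   import ray

   @ray.remote
   class Rollout:
       def generate_sequences(self, prompts):
           return [p + " ..." for p in prompts]

   workers = [Rollout.remote() for _ in range(4)]
   chunks = [["a"], ["b"], ["c"], ["d"]]          # manual dispatch
   futures = [w.generate_sequences.remote(c) for w, c in zip(workers, chunks)]
   results = ray.get(futures)                     # manual collect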

To hide from users this fan-out of a single method call across multiple
Ray actors, we introduced the following components:

-  ``WorkerGroup`` – manages a group of remote workers and provides
   a unified interface for multi-process distributed computation;
-  ``ResourcePool`` – binds computational resources to worker
   processes;
-  ``ClassWithInitArgs`` – enables delayed remote instantiation with
   specified initialization arguments.

--------------

A Running Example: ``generate_sequences``
-----------------------------------------

To illustrate the design, we walk through how the ``generate_sequences``
method in the ``ActorRolloutRefWorker`` class is registered and invoked
across distributed workers.

--------------

Step 1: Register with a Decorator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The first step is to define the ``generate_sequences`` method and
decorate it with ``@register``, since it will be called from the driver
script.

**Source:**
`fsdp_workers.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/workers/fsdp_workers.py#L528>`__

.. code:: python

   class ActorRolloutRefWorker(Worker):
       ...
       @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
       def generate_sequences(self, prompts: DataProto):
           prompts = prompts.to(torch.cuda.current_device())
           ...

The ``@register`` decorator adds metadata to the ``generate_sequences``
method. Currently, it does not change the method's behavior; it only
attaches the registration attributes to the method under a magic key
(``MAGIC_ATTR``):

**Source:**
`decorator.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L411>`__

.. code:: python

   def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
       ...
       def decorator(func):
           @wraps(func)
           def inner(*args, **kwargs):
               if materialize_futures:
                   args, kwargs = _materialize_futures(*args, **kwargs)
               return func(*args, **kwargs)

           attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
           setattr(inner, MAGIC_ATTR, attrs)
           return inner

       return decorator

As the code shows, the values of ``dispatch_mode``, ``execute_mode`` and
``blocking`` are attached to the ``generate_sequences`` method.
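
To see the effect, one can inspect the attached attributes directly. The
snippet below is illustrative and assumes ``MAGIC_ATTR`` is importable
from ``verl.single_controller.base.decorator``:

.. code:: python

   from verl.single_controller.base.decorator import MAGIC_ATTR
   from verl.workers.fsdp_workers import ActorRolloutRefWorker

   # The decorator stored its arguments on the function object itself.
   attrs = getattr(ActorRolloutRefWorker.generate_sequences, MAGIC_ATTR)
   # e.g. {'dispatch_mode': <Dispatch.DP_COMPUTE_PROTO>, 'execute_mode': <Execute.ALL>, 'blocking': True}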

--------------

Step 2: Binding During Initialization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These attached attributes are extracted and utilized when
``ActorRolloutRefWorker``, wrapped in a ``RayClassWithInitArgs``, is
passed into a ``RayWorkerGroup``.

**Source:**
`main_generation.py <https://github.com/volcengine/verl/blob/4ae9a0fdab229f75f080e9478807783ed4c97154/verl/trainer/main_generation.py#L82>`__

.. code:: python

   ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
   wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)

During the
`initialization <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L184>`__
of ``RayWorkerGroup``, two key steps occur:

1. Worker instances (Ray actors) are created:
   `RayWorkerGroup._init_with_resource_pool <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L211>`__
2. Methods decorated with ``@register`` are bound to ``RayWorkerGroup``:
   `RayWorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L214>`__

.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true
   :alt: initialization_and_binding_of_worker_group

   initialization_and_binding_of_worker_group

The binding procedure is the heart of ``verl.single_controller``.

**Key function:**
`WorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/worker_group.py#L143>`__

.. code:: python

   def _bind_worker_method(self, user_defined_cls, func_generator):
       ...
       for method_name in dir(user_defined_cls):
           try:
               method = getattr(user_defined_cls, method_name)
               assert callable(method)
           except Exception:
               continue  # Skip properties
           <<<to be continued 1>>>

When a method has the ``MAGIC_ATTR``, the attributes set by
``@register`` are extracted:

.. code:: python

           <<<continued 1>>>
           if hasattr(method, MAGIC_ATTR):
               attribute = getattr(method, MAGIC_ATTR)
               dispatch_mode = attribute["dispatch_mode"]
               execute_mode = attribute["execute_mode"]
               blocking = attribute["blocking"]

               <<<to be continued 2>>>

As shown in the flowchart above, these attributes are fed into
``func_generator``. However, ``func_generator`` expects ``method_name``,
``dispatch_fn``, ``collect_fn``, ``execute_fn``, and ``blocking``, so we
first need to look up the ``dispatch_fn`` and ``collect_fn`` associated
with the ``dispatch_mode`` (``DP_COMPUTE_PROTO``) in
`DISPATCH_MODE_FN_REGISTRY <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L387>`__:

.. code:: python

   DISPATCH_MODE_FN_REGISTRY = {
       Dispatch.ONE_TO_ALL: {
           "dispatch_fn": dispatch_one_to_all,
           "collect_fn": collect_all_to_all,
       },
       ...
       Dispatch.DP_COMPUTE_PROTO: {
           "dispatch_fn": dispatch_dp_compute_data_proto,
           "collect_fn": collect_dp_compute_data_proto,
       },
       ...
   }
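
Conceptually, the lookup performed during binding boils down to a plain
dictionary access (the actual code wraps it in a small helper):

.. code:: python

   from verl.single_controller.base.decorator import DISPATCH_MODE_FN_REGISTRY, Dispatch

   # Sketch of the lookup for the dispatch mode of generate_sequences.
   fns = DISPATCH_MODE_FN_REGISTRY[Dispatch.DP_COMPUTE_PROTO]
   dispatch_fn = fns["dispatch_fn"]   # dispatch_dp_compute_data_proto
   collect_fn = fns["collect_fn"]     # collect_dp_compute_data_proto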

Similarly, the ``execute_fn`` is selected according to ``execute_mode``
and resolved at binding time:

.. code:: python

               <<<continued 2>>>
               # get execute_fn_name
               execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
               wg_execute_fn_name = execute_mode["execute_fn_name"]

               # get execute_fn from string
               try:
                   execute_fn = getattr(self, wg_execute_fn_name)
                   assert callable(execute_fn), "execute_fn must be callable"
               except Exception:
                   print(f"execute_fn {wg_execute_fn_name} is invalid")
                   raise
               <<<to be continued 3>>>

In the ``generate_sequences`` case:

-  ``dispatch_mode = Dispatch.DP_COMPUTE_PROTO``
-  ``dispatch_fn = dispatch_dp_compute_data_proto``
-  ``collect_fn = collect_dp_compute_data_proto``
-  ``execute_fn = RayWorkerGroup.execute_all``

ONE_TO_ALL vs. DP_COMPUTE_PROTO
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Each ``dispatch_mode`` is associated with a ``dispatch_fn`` and a
``collect_fn``. As the name implies, ``dispatch_fn`` processes the input
arguments passed to the ``WorkerGroup`` method and generates a batch
(list) of input arguments, each of which is fed to one worker attached
to the ``WorkerGroup``.

The ``dispatch_fn`` of ``ONE_TO_ALL`` is
`dispatch_one_to_all <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L119>`__,
which simply duplicates every input argument into N replicas, where N
equals the number of workers attached to the ``worker_group``:

.. code:: python

   def dispatch_one_to_all(worker_group, *args, **kwargs):
       args = tuple([arg] * worker_group.world_size for arg in args)
       kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
       return args, kwargs
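
Given this definition, its effect is easy to see with a mocked
``worker_group`` (the real argument is a ``WorkerGroup`` instance):

.. code:: python

   from types import SimpleNamespace

   from verl.single_controller.base.decorator import dispatch_one_to_all

   wg = SimpleNamespace(world_size=2)   # stand-in for a WorkerGroup
   args, kwargs = dispatch_one_to_all(wg, "cfg", seed=42)
   # args   == (["cfg", "cfg"],)
   # kwargs == {"seed": [42, 42]}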

The ``dispatch_fn`` of ``DP_COMPUTE_PROTO`` is
`dispatch_dp_compute_data_proto <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L350>`__,
which uses ``DataProto.chunk`` to split a large ``DataProto`` into N
smaller ``DataProto`` objects, where N equals the world size (the number
of workers) of the ``worker_group``:

.. code:: python

   def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
       from verl.single_controller.base.worker_group import WorkerGroup

       assert isinstance(worker_group, WorkerGroup)
       # Note: enable auto padding for dp compute DataProto
       splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(
           worker_group.world_size,
           *args,
           **kwargs,
       )
       return splitted_args, splitted_kwargs
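
The splitting itself is done by ``DataProto.chunk``. Below is a minimal
sketch of what it amounts to, assuming ``DataProto.from_dict`` as the
constructor (see ``verl.protocol`` for the exact API):

.. code:: python

   import torch

   from verl import DataProto

   batch = DataProto.from_dict(tensors={"input_ids": torch.arange(8).reshape(4, 2)})
   chunks = batch.chunk(2)   # two DataProto objects with 2 rows each, one per worker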

The ``collect_fn`` follows the same pattern: it processes the batch
(list) of values returned by all workers of a ``WorkerGroup`` and merges
them either into a list, as ``collect_all_to_all`` does, or into a
single large ``DataProto``, as ``collect_dp_compute_data_proto`` does.
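
In simplified form (the real implementations, with their extra checks,
live in ``decorator.py``), the two flavors behave roughly like this:

.. code:: python

   def collect_all_to_all_sketch(worker_group, output):
       # keep the per-worker results as a plain list
       return output

   def collect_dp_compute_data_proto_sketch(worker_group, output):
       # merge the per-worker DataProto chunks back into one large DataProto
       from verl import DataProto
       return DataProto.concat(output)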

Finally, a new method is dynamically generated using ``func_generator``
and added to the ``WorkerGroup`` instance:

.. code:: python

               <<<continued 3>>>
               # bind a new method to the RayWorkerGroup
               func = func_generator(
                   self,
                   method_name,
                   dispatch_fn=dispatch_fn,
                   collect_fn=collect_fn,
                   execute_fn=execute_fn,
                   blocking=blocking,
               )

               try:
                   setattr(self, method_name, func)
                   method_names.append(method_name)
               except Exception as e:
                   raise ValueError(f"Fail to set method_name {method_name}") from e

This makes the method invocable via the ``WorkerGroup`` interface.
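
Putting the pieces together, the generated method behaves roughly like
the sketch below. This is not the literal ``func_generator`` from
``ray/base.py``, but a simplified restatement of the pipeline it wires
up:

.. code:: python

   import ray

   def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
       def func(*args, **kwargs):
           args, kwargs = dispatch_fn(self, *args, **kwargs)   # split the inputs per worker
           output = execute_fn(method_name, *args, **kwargs)   # issue the remote calls
           if blocking:
               output = ray.get(output)                        # wait for the Ray futures
           return collect_fn(self, output)                     # merge the results
       return func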

--------------

Step 3: Call Chain
~~~~~~~~~~~~~~~~~~

All the machinery above ensures that distributed calls feel identical to
single-process ones. In the original single-process script, the code
looks like:

.. code:: python

   rollout = Rollout()
   rollout.generate_sequences(batch)

With ``verl``, the multiprocess program becomes:

.. code:: python

   resource_pool = RayResourcePool(process_on_nodes=[4])
   ray_cls_with_init = RayClassWithInitArgs(ray.remote(Rollout))
   rollout = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
   rollout.generate_sequences(batch)

.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true
   :alt: call_chain_of_generate_sequences

   call_chain_of_generate_sequences

Behind this simple call:

-  ``dispatch_fn`` splits the input across workers;
-  ``execute_fn`` performs the actual remote invocations;
-  ``collect_fn`` gathers the results.

All of this is abstracted away, enabling developers to write distributed
code with minimal changes to their existing logic.

--------------

Beyond RL Post-Training: Generalizing ``verl.single_controller``
----------------------------------------------------------------

The ``verl.single_controller`` module generalizes well beyond
reinforcement learning. It provides a clean abstraction to batch-process
remote method calls, with automatic input/output handling.

By minimizing the gap between single-process and multi-process scripts,
``verl.single_controller`` opens the door to distributed computing in
broader domains — not limited to RL post-training.

We hope this design inspires more examples and extensions from the
community.