jerrrrry / verl_grpo · Commit 7f6cc211
Authored Aug 05, 2025 by jerrrrry

Initial commit
Showing 1 changed file with 113 additions and 0 deletions.
tests/single_controller/test_ray_collectives.py (new file, mode 100644)  +113 −0
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test for using ray collective group.
Suppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to
Rollout relationship by using collective groups
Actor: rank 0, 1 - Rollout rank 0
Rollout rank 2, 3 - Rollout rank 1
Then, we initiate 4 p2p comms from actor to rollout
"""
import ray
import ray.util.collective as collective
import torch

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class Actor(Worker):
    @register(Dispatch.ONE_TO_ALL)
    def init(self):
        # Each pair of actor ranks maps to one rollout rank (0, 1 -> 0; 2, 3 -> 1).
        remote_rank = self.rank // 2
        self.group_name = f"A{self.rank}_R{remote_rank}"
        collective.init_collective_group(world_size=2, rank=0, backend="nccl", group_name=self.group_name)

    @register(Dispatch.ONE_TO_ALL, blocking=False)
    def send_tensors(self):
        # Send a tensor filled with this worker's rank to the paired rollout worker.
        tensor = torch.ones(size=(4,), dtype=torch.float32, device="cuda") * self.rank
        collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name)
@ray.remote
class Rollout(Worker):
    @register(Dispatch.ONE_TO_ALL)
    def init(self):
        # Each rollout rank pairs with two actor ranks and joins one 2-way group per actor.
        self.remote_first_rank = self.rank * 2
        self.remote_second_rank = self.remote_first_rank + 1
        self.first_group_name = f"A{self.remote_first_rank}_R{self.rank}"
        self.second_group_name = f"A{self.remote_second_rank}_R{self.rank}"
        collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.first_group_name)
        collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.second_group_name)

    @register(Dispatch.ONE_TO_ALL, blocking=False)
    def receive_tensors(self):
        # Receive one tensor from each of the two paired actor workers.
        self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device="cuda")
        self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device="cuda")
        collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name)
        collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name)

    @register(Dispatch.ONE_TO_ALL)
    def get_tensors(self):
        return {f"src_{self.remote_first_rank}": self.tensor1, f"src_{self.remote_second_rank}": self.tensor2}
def test_ray_collective_group():
    ray.init()

    actor_resource_pool = RayResourcePool([4])
    rollout_resource_pool = RayResourcePool([2])
    actor_cls = RayClassWithInitArgs(cls=Actor)
    rollout_cls = RayClassWithInitArgs(cls=Rollout)
    actor_wg = RayWorkerGroup(resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix="collective_group_actor")
    rollout_wg = RayWorkerGroup(resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, name_prefix="collective_group_rollout")

    actor_wg.init()
    rollout_wg.init()

    out1 = actor_wg.send_tensors()
    out2 = rollout_wg.receive_tensors()
    # block to wait for the non-blocking send/recv calls to complete
    ray.get(out1)
    ray.get(out2)

    output = rollout_wg.get_tensors()
    # merge the per-rollout-worker dicts into a single {src_rank: tensor} mapping
    rollout_0_output = output[0]
    rollout_1_output = output[1]
    output = rollout_0_output | rollout_1_output
    print(output)

    # actor rank i sent a 4-element tensor with every entry equal to i
    for i in range(4):
        assert torch.sum(output[f"src_{i}"]).item() == 4 * i

    ray.shutdown()


if __name__ == "__main__":
    test_ray_collective_group()
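
Note (not part of the committed file): the module docstring describes how each pair of Actor ranks maps to one Rollout rank, and both Actor.init and Rollout.init derive the NCCL group name from that pairing. Below is a minimal, self-contained sketch of that mapping and naming scheme; the expected_groups helper is hypothetical and exists only for illustration.

# Illustration only: the actor-to-rollout pairing and the "A{actor}_R{rollout}" group names
# built by the test above. expected_groups is a hypothetical helper, not part of verl.
def expected_groups(num_actor_ranks: int = 4):
    groups = []
    for actor_rank in range(num_actor_ranks):
        rollout_rank = actor_rank // 2  # actor ranks 0, 1 -> rollout 0; 2, 3 -> rollout 1
        groups.append(f"A{actor_rank}_R{rollout_rank}")
    return groups

print(expected_groups())  # ['A0_R0', 'A1_R0', 'A2_R1', 'A3_R1']

Run standalone via the __main__ guard, the test itself presumably needs a Ray cluster with at least 6 CUDA GPUs (4 Actor workers plus 2 Rollout workers), since every collective group uses the nccl backend and the tensors live on "cuda".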