Commit 7f6cc211 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #2874 failed with stages
in 0 seconds
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test for using ray collective group.
Suppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to
Rollout relationship by using collective groups
Actor: rank 0, 1 - Rollout rank 0
Rollout rank 2, 3 - Rollout rank 1
Then, we initiate 4 p2p comms from actor to rollout
"""
import ray
import ray.util.collective as collective
import torch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class Actor(Worker):
@register(Dispatch.ONE_TO_ALL)
def init(self):
remote_rank = self.rank // 2
self.group_name = f"A{self.rank}_R{remote_rank}"
collective.init_collective_group(world_size=2, rank=0, backend="nccl", group_name=self.group_name)
@register(Dispatch.ONE_TO_ALL, blocking=False)
def send_tensors(self):
tensor = torch.ones(size=(4,), dtype=torch.float32, device="cuda") * self.rank
collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name)
@ray.remote
class Rollout(Worker):
@register(Dispatch.ONE_TO_ALL)
def init(self):
self.remote_first_rank = self.rank * 2
self.remote_second_rank = self.remote_first_rank + 1
self.first_group_name = f"A{self.remote_first_rank}_R{self.rank}"
self.second_group_name = f"A{self.remote_second_rank}_R{self.rank}"
collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.first_group_name)
collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.second_group_name)
@register(Dispatch.ONE_TO_ALL, blocking=False)
def receive_tensors(self):
self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device="cuda")
self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device="cuda")
collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name)
collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name)
@register(Dispatch.ONE_TO_ALL)
def get_tensors(self):
return {f"src_{self.remote_first_rank}": self.tensor1, f"src_{self.remote_second_rank}": self.tensor2}
def test_ray_collective_group():
ray.init()
actor_resource_pool = RayResourcePool([4])
rollout_resource_pool = RayResourcePool([2])
actor_cls = RayClassWithInitArgs(cls=Actor)
rollout_cls = RayClassWithInitArgs(cls=Rollout)
actor_wg = RayWorkerGroup(
resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix="collective_group_actor"
)
rollout_wg = RayWorkerGroup(
resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, name_prefix="collective_group_rollout"
)
actor_wg.init()
rollout_wg.init()
out1 = actor_wg.send_tensors()
out2 = rollout_wg.receive_tensors()
# block to wait
ray.get(out1)
ray.get(out2)
output = rollout_wg.get_tensors()
rollout_0_output = output[0]
rollout_1_output = output[1]
output = rollout_0_output | rollout_1_output
print(output)
for i in range(4):
assert torch.sum(output[f"src_{i}"]).item() == 4 * i
ray.shutdown()
if __name__ == "__main__":
test_ray_collective_group()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment