# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Client side of a detached worker group: look up the server's workers by
name and send them an RPC-style training request.
"""

import ray
import torch
from tensordict import TensorDict

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup

from server import Trainer


def compute_position_id_with_mask(mask):
    # Position ids count real (unmasked) tokens up from 0; padding
    # positions are clipped to 0 instead of going negative.
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)


if __name__ == '__main__':
    # Attach to the running cluster, in the namespace the server used.
    ray.init(address='auto', namespace='verl')

    # Re-attach to the detached workers via their registered actor names.
    worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
                                                          ray_cls_with_init=cls_with_init_args)

    batch_size = 16
    sequence_length = 1024

    # Build a random batch of data for the Trainer to train on.
    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device='cuda')
    attention_mask = torch.ones_like(input_ids)
    position_ids = compute_position_id_with_mask(attention_mask)

    data = DataProto(batch=TensorDict(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        },
        batch_size=batch_size),
        meta_info={})

    # Dispatch the batch to the worker group and collect the result.
    output = worker_group.train_model(data)
    print(output)
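
# --- Illustrative sketch (not part of the original client) ---
# A minimal, CPU-only sanity check of compute_position_id_with_mask,
# assuming the usual left-padding convention where masked positions are 0.
# Padding positions repeat position 0 and real tokens count up from 0;
# this can be run standalone, with no Ray cluster or server required:
#
#   mask = torch.tensor([[0, 0, 1, 1, 1],
#                        [1, 1, 1, 1, 1]])
#   print(compute_position_id_with_mask(mask))
#   # tensor([[0, 0, 0, 1, 2],
#   #         [0, 1, 2, 3, 4]])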