Commit f87b35b2 authored by jerrrrry

Initial commit
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# setup.py is the fallback installation script when pyproject.toml does not work
from setuptools import setup, find_packages
import os
version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'verl/version/version')) as f:
    __version__ = f.read().strip()

install_requires = [
    'accelerate',
    'codetiming',
    'datasets',
    'dill',
    'hydra-core',
    'numpy',
    'pandas',
    'peft',
    'pyarrow>=15.0.0',
    'pybind11',
    'pylatexenc',
    'ray[default]>=2.10',
    'tensordict<=0.6.2',
    'torchdata',
    'transformers',
    'wandb',
]

TEST_REQUIRES = ['pytest', 'yapf', 'py-spy']
PRIME_REQUIRES = ['pyext']
GEO_REQUIRES = ['mathruler']
GPU_REQUIRES = ['liger-kernel', 'flash-attn']
MATH_REQUIRES = ['math-verify']  # math-verify is an optional dependency
VLLM_REQUIRES = ['tensordict<=0.6.2', 'vllm<=0.8.2']
SGLANG_REQUIRES = [
    'tensordict<=0.6.2',
    'sglang[all]==0.4.4.post4',
    'torch-memory-saver>=0.0.5'
]

extras_require = {
    'test': TEST_REQUIRES,
    'prime': PRIME_REQUIRES,
    'geo': GEO_REQUIRES,
    'gpu': GPU_REQUIRES,
    'math': MATH_REQUIRES,
    'vllm': VLLM_REQUIRES,
    'sglang': SGLANG_REQUIRES,
}

from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()

setup(
    name='verl',
    version=__version__,
    package_dir={'': '.'},
    packages=find_packages(where='.'),
    url='https://github.com/volcengine/verl',
    license='Apache 2.0',
    author='Bytedance - Seed - MLSys',
    author_email='zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk',
    description='verl: Volcano Engine Reinforcement Learning for LLM',
    install_requires=install_requires,
    extras_require=extras_require,
    package_data={
        '': ['version/*'],
        'verl': ['trainer/config/*.yaml'],
    },
    include_package_data=True,
    long_description=long_description,
    long_description_content_type='text/markdown'
)
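
# Installation sketch (an assumption, not part of the original script; the extras
# names follow the extras_require mapping above):
#   pip install -e .            # core dependencies only
#   pip install -e .[test]      # plus test tooling
#   pip install -e .[vllm]      # plus the vLLM rollout backend
#   pip install -e .[sglang]    # plus the SGLang rollout backend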
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import shutil
import torch
import copy
import torch.distributed
from torch.distributed import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2Config
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \
    CPUOffload


def test_fsdp_ckpt():
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    local_rank, rank, world_size = initialize_global_process_group()
    device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))

    model_name = 'Qwen/Qwen2.5-0.5B-Instruct'
    config = Qwen2Config(num_hidden_layers=1)

    with torch.device('cuda'):
        model = AutoModelForCausalLM.from_config(config=config,
                                                 torch_dtype=torch.bfloat16,
                                                 attn_implementation='flash_attention_2')
        model = model.to(device='cuda')

    # Wrap model with FSDP
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)
    model = FSDP(model,
                 use_orig_params=False,
                 device_id=torch.cuda.current_device(),
                 sharding_strategy=ShardingStrategy.FULL_SHARD,
                 mixed_precision=mixed_precision,
                 device_mesh=device_mesh)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    # Create checkpoint manager
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    checkpoint_manager = FSDPCheckpointManager(model=model,
                                               optimizer=optimizer,
                                               lr_scheduler=lr_scheduler,
                                               tokenizer=tokenizer)

    # Generate sample input
    batch_size = 2
    seq_len = 32
    vocab_size = 32000

    # First input for initial update
    input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
    attention_mask1 = torch.ones_like(input_ids1)

    # Second input for verification
    input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
    attention_mask2 = torch.ones_like(input_ids2)

    # Step 1: Initial update and save checkpoint
    outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)
    loss1 = outputs1.logits.mean()
    loss1.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Save checkpoint after first update
    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, 'checkpoint')
    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)

    # Step 2: Second update and forward pass
    outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)
    loss2 = outputs2.logits.mean()
    loss2.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after second update
    with torch.no_grad():
        logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits

    # Step 3: Load checkpoint and repeat second update
    checkpoint_manager.load_checkpoint(checkpoint_path)

    # Repeat the second update with same input
    outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2)
    loss3 = outputs3.logits.mean()
    loss3.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after loaded checkpoint and update
    with torch.no_grad():
        logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits

    # Step 4: Verify outputs match
    torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0)
    print("Checkpoint save/load test passed!")

    # Cleanup
    shutil.rmtree(temp_dir)
    torch.distributed.barrier()


if __name__ == '__main__':
    test_fsdp_ckpt()
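
# Launch hint (assumption, not part of the original test): the test initializes
# torch.distributed and asserts at least 2 GPUs, so it is expected to be launched
# with something like
#   torchrun --standalone --nproc-per-node=2 <path/to/this_test_file.py>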
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env bash
set -e -x
torchrun --nproc-per-node=4 --standalone tests/distributed/test_tensor_dict.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['NCCL_DEBUG'] = 'WARN'
from verl.protocol import all_gather_data_proto, DataProto
from verl.utils.distributed import initialize_global_process_group
import torch
import torch.distributed
import numpy as np
def test_all_gather_data_proto():
    device_mesh = torch.distributed.device_mesh.init_device_mesh('cuda', mesh_shape=[2, 2], mesh_dim_names=['dp', 'tp'])

    global_rank = torch.distributed.get_rank()

    obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]])
    labels = ['a', 'b'] if global_rank % 2 == 0 else ['b', 'a']
    labels = np.array(labels, dtype=object)
    data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})

    all_gather_data_proto(data=data, process_group=device_mesh.get_group('dp'))

    if global_rank == 0:
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda')
        expected_labels = ['a', 'b', 'a', 'b']
    elif global_rank == 1:
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda')
        expected_labels = ['b', 'a', 'b', 'a']
    elif global_rank == 2:
        expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device='cuda')
        expected_labels = ['a', 'b', 'a', 'b']
    elif global_rank == 3:
        expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device='cuda')
        expected_labels = ['b', 'a', 'b', 'a']

    torch.testing.assert_close(data.batch['obs'], expected_obs, atol=0, rtol=0)
    assert (data.non_tensor_batch['labels'] == expected_labels).all()
    assert data.meta_info == {'info': 'test_info'}


def test_vocab_parallel_entropy():
    from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
    from verl.utils.debug import log_gpu_memory_usage
    from verl.utils.torch_functional import entropy_from_logits
    from megatron.core import parallel_state as mpu

    mpu.initialize_model_parallel(tensor_model_parallel_size=2,
                                  pipeline_model_parallel_size=1,
                                  virtual_pipeline_model_parallel_size=None)

    batch_size = 2
    seqlen = 128
    vocab_size = 155136

    logits = torch.randn(batch_size * seqlen, vocab_size, device='cuda', requires_grad=True)
    target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device='cuda', dtype=torch.int64)

    # broadcast across tp
    torch.distributed.broadcast(logits,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(target,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    tp_rank = mpu.get_tensor_model_parallel_rank()
    vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()

    # get the local logits of each tp
    vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) *
                                                    vocab_size_per_tp].requires_grad_()
    logits.grad = None
    vocab_parallel_logits.grad = None

    log_gpu_memory_usage('begin')
    output_entropy = vocab_parallel_entropy(vocab_parallel_logits)
    log_gpu_memory_usage('after forward')

    grad_output = torch.randn_like(output_entropy)
    output_entropy.backward(grad_output)
    log_gpu_memory_usage('after backward')

    target_entropy = entropy_from_logits(logits)
    torch.testing.assert_close(output_entropy, target_entropy)
    target_entropy.backward(grad_output)
    torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp],
                               vocab_parallel_logits.grad)

    # make sure logits is not altered
    torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp],
                               vocab_parallel_logits)

    if mpu.get_tensor_model_parallel_rank() == 0:
        print('test_vocab_parallel_entropy passes')

    mpu.destroy_model_parallel()


if __name__ == '__main__':
    local_rank, rank, world_size = initialize_global_process_group()
    test_all_gather_data_proto()
    test_vocab_parallel_entropy()
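
# Launch hint (assumption, not part of the original test): the 2x2 device mesh and
# tensor-parallel size of 2 above require 4 ranks, e.g.
#   torchrun --standalone --nproc-per-node=4 <path/to/this_test_file.py>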
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.e2e.envs.digit_completion import DigitCompletion, generate_ground_truth_response
from torch.utils import data
import os
if __name__ == '__main__':
    simple_task = DigitCompletion(max_number=9, max_diff=9, max_num_in_response=9)
    all_prompts = simple_task.get_all_prompts()

    # 21 * 6 * 4
    train_data, test_data = data.random_split(all_prompts, lengths=[0.8, 0.2])
    train_data = list(train_data)
    test_data = list(test_data)

    train_data = [[{'role': 'user', 'content': str(item)}] for item in train_data]
    test_data = [[{'role': 'user', 'content': str(item)}] for item in test_data]

    print(f'Size of train: {len(train_data)}, size of test: {len(test_data)}')

    train_data = {'prompt': train_data}
    test_data = {'prompt': test_data}
    model_folder = os.path.dirname(os.path.abspath(__file__))

    import pandas as pd

    train_data_frame = pd.DataFrame(train_data)
    test_data_frame = pd.DataFrame(test_data)
    train_data_frame.to_parquet(os.path.join(model_folder, 'train.parquet'))
    test_data_frame.to_parquet(os.path.join(model_folder, 'test.parquet'))
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": null,
  "eos_token_id": 1,
  "hidden_act": "silu",
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 344,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "num_key_value_heads": 4,
  "pad_token_id": 2,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.3",
  "use_cache": true,
  "vocab_size": 16
}
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a random model and tokenizer for PPO training
"""
import torch
import os
from transformers import AutoModelForCausalLM, LlamaConfig, AutoTokenizer
from tests.e2e.envs.digit_completion import CharTokenizer
tokenizer = CharTokenizer(
    characters=['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',', ':'],
    model_max_length=2048,
    chat_template=
    "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}"
)

config = LlamaConfig(vocab_size=(tokenizer.vocab_size + 16 - 1) // 16 * 16,
                     hidden_size=128,
                     intermediate_size=344,
                     num_hidden_layers=4,
                     num_attention_heads=4,
                     num_key_value_heads=4,
                     pad_token_id=tokenizer.pad_token_id,
                     bos_token_id=tokenizer.bos_token_id,
                     eos_token_id=tokenizer.eos_token_id)

model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

model_folder = os.path.dirname(os.path.abspath(__file__))
os.makedirs(model_folder, exist_ok=True)
model.save_pretrained(model_folder)

tokenizer_folder = model_folder
tokenizer.save_pretrained(tokenizer_folder)

load_tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder)

chat = [{'role': 'user', 'content': '1,0:2,3'}]

load_tokenizer.padding_side = 'left'
print(
    load_tokenizer.apply_chat_template(chat,
                                       tokenize=True,
                                       add_generation_prompt=True,
                                       max_length=10,
                                       padding='max_length'))
{
  "_from_model_config": true,
  "eos_token_id": 1,
  "pad_token_id": 2,
  "transformers_version": "4.43.3"
}
{
  "char_ords": [
    48,
    49,
    50,
    51,
    52,
    53,
    54,
    55,
    56,
    57,
    44,
    58
  ],
  "model_max_length": 2048,
  "chat_template": "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = message['role'] %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ sep_token }}{% endif %}"
}
# Digit completion
This is an example of solving a digit completion problem. The problem is defined as follows:
The prompt is a sequence of numbers with a fixed common difference. The agent's goal is to complete the next N numbers.
When the max number is reached, the sequence wraps around (the next number is taken modulo the max number).
For example,
- prompt = [1, 2, 3]
- N = 5
- max_number = 6
The response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1].
# Environment definition
The core logic of the task is defined in verl/envs/digit_completion/task.py.
It is highly recommended to take a look at it for a better understanding.
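
For a quick feel of the task, the environment and the ground-truth generator can also be exercised directly. The following is a minimal sketch (assuming the `tests.e2e.envs.digit_completion` import path used by the data-generation script in this repository):

```python
from tests.e2e.envs.digit_completion import DigitCompletion, generate_ground_truth_response

# Prompts have the form '{num1},{num2}:{max_number},{num_to_complete}'
task = DigitCompletion(max_number=9, max_diff=9, max_num_in_response=9)
prompt = task.sample_str_prompts()
print(prompt, '->', generate_ground_truth_response(prompt))
```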
# Run experiments
Users must specify the config path and config name (and the model config path relative to the current working directory).
```bash
# cd examples/arithmetic_sequence/rl
# Specify the config path and config name (current working dir)
python3 -m verl.trainer.ppo.ray_megatron_train_synchronous --config-path=$(pwd)/config --config-name='ray_megatron'
# The default relative path of model config is 'config/model_config', if you want to change it, you can rewrite it in ray_megatron.yaml or using:
python3 -m verl.trainer.ppo.ray_megatron_train_synchronous --config-path=$(pwd)/config --config-name='ray_megatron' ++model.base_path=config/model_config
```
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Using FSDPTrainer
"""
import os
import hydra
import ray
import torch
from transformers import PreTrainedTokenizer, AutoTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.utils.fs import copy_to_local
from tests.e2e.envs.digit_completion import CharTokenizer
def make_reward_function(tokenizer, num_examine):

    def arithmetic_sequence_reward_function(data: DataProto, return_dict: bool = False):
        from tests.e2e.envs.digit_completion.task import compute_reward

        reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)

        for i in range(data.batch.batch_size[0]):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch['prompts']
            prompt_length = prompt_ids.shape[-1]

            # extract raw prompt
            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]

            # extract response
            response_ids = data_item.batch['responses']
            response_length = response_ids.shape[-1]
            response_mask = data.batch['attention_mask'][i][-response_length:]
            valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt = tokenizer.decode(valid_prompt_ids)
            response = tokenizer.decode(valid_response_ids)

            # remove bos and eos
            prompt = prompt.replace(tokenizer.sep_token, '')
            response = response.replace(tokenizer.eos_token, '')

            if i < num_examine:
                print(prompt, response)

            reward_output = compute_reward(prompt, response)
            dense_reward = reward_output[0].tolist()
            ground_truth_response = reward_output[1]['ground_truth_response']

            if len(dense_reward) > 0:
                last_reward = dense_reward[-1]
            else:
                if len(ground_truth_response) == 0:
                    last_reward = 1
                else:
                    last_reward = 0

            # pad to response_length
            for _ in range(reward_tensor.shape[-1] - len(dense_reward)):
                dense_reward.append(last_reward)

            dense_reward = torch.as_tensor(dense_reward, dtype=torch.float32, device=reward_tensor.device)
            reward_tensor[i] = dense_reward * response_mask

        if return_dict:
            return {"reward_tensor": reward_tensor}
        else:
            return reward_tensor

    return arithmetic_sequence_reward_function
@hydra.main(config_path='../../../../verl/trainer/config', config_name='ppo_trainer', version_base=None)
def main(config):
    ray.init(
        runtime_env={
            'env_vars': {
                'MEGATRON_USE_CUDA_TIMER': '0',
                'MEGATRON_START_PROCESS_TIMER': 'False',
                'TOKENIZERS_PARALLELISM': 'true',
                'NCCL_DEBUG': 'WARN'
            }
        })

    # print initial config
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values

    # print the config after batch_size normalization
    print('Config after normalizing batch_size')
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values

    # download the checkpoint from hdfs
    local_path = copy_to_local(config.actor_rollout_ref.model.path)
    local_path = os.path.expanduser(local_path)

    # instantiate tokenizer
    tokenizer = AutoTokenizer.from_pretrained(local_path)
    print(f'Tokenizer vocab_size: {tokenizer.vocab_size}')

    # define worker classes
    from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
    from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

    role_worker_mapping = {
        Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
        Role.Critic: ray.remote(CriticWorker),
    }

    global_pool_id = 'global_pool'
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
        Role.Critic: global_pool_id,
    }

    # add a reference policy worker when a KL term is used (in the reward or as a loss)
    if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
        role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
        mapping[Role.RefPolicy] = global_pool_id

    reward_fn = make_reward_function(tokenizer=tokenizer, num_examine=1)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

    trainer = RayPPOTrainer(config=config,
                            tokenizer=tokenizer,
                            role_worker_mapping=role_worker_mapping,
                            resource_pool_manager=resource_pool_manager,
                            reward_fn=reward_fn,
                            val_reward_fn=reward_fn)
    trainer.init_workers()
    trainer.fit()


if __name__ == '__main__':
    main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def check_congratulations_in_file(output_file):
    with open(output_file, 'r') as f:
        output = f.read()

    success_message = "Congratulations!!! You have called my_reward_function successfully!!!"
    assert success_message in output, f'Success message of my_reward_function not found in {output_file}'
    print("Check passes")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', required=True, type=str)
    args = parser.parse_args()
    check_congratulations_in_file(args.output_file)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
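
# Note (assumption, inferred from the parser below, not part of the original file):
# the training log is expected to contain lines like
#   step:10 - critic/rewards/mean:0.25 - ...
# i.e. lines starting with 'step' made of 'key:value' pairs separated by ' - '.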
def extract_reward_from_line(line):
    # TODO: this function needs error handling
    try:
        key_vals = line.split(' - ')
        for key_val in key_vals:
            key, val = key_val.split(':')
            if key == 'critic/rewards/mean':
                reward = float(val)
                return reward
        return -np.inf
    except Exception:
        return -np.inf


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_file', required=True, type=str)
    parser.add_argument('--target', type=float, default=0.2, help='target reward score')
    args = parser.parse_args()

    with open(args.output_file, 'r') as f:
        output = f.read().split('\n')

    best_reward = -np.inf
    for line in output:
        if line.startswith('step'):
            reward = extract_reward_from_line(line)
            if reward > best_reward:
                best_reward = reward

    print(f'Best reward is {best_reward}')
    assert best_reward > args.target, f'Best reward must be greater than {args.target}. best_reward: {best_reward}'
    print('Check passes')
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .digit_completion import DigitCompletion
__all__ = ['DigitCompletion']
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .task import DigitCompletion, generate_ground_truth_response
from .tokenizer import CharTokenizer
from transformers import AutoTokenizer, LlamaConfig
AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True)
__all__ = ['DigitCompletion', 'generate_ground_truth_response', 'CharTokenizer']
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task and environment definition for digit completion."""
import numpy as np
class DigitCompletion(object):
"""
The implementation of a simple digit completion task.
The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers.
If the max number is reached, the next number should be modulo with max number.
For example,
- prompt = [1, 2, 3]
- N = 5
- max_number = 6
the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]
Note that the tokenizer is char-level to increase the difficulty.
"""
def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0):
"""
Args:
max_number: the maximum number allowed in the arithmetic sequence
max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff]
max_num_in_response: the maximum number in the response
"""
super().__init__()
self.max_number = max_number
self.max_diff = max_diff
self.max_num_in_response = max_num_in_response
assert self.max_num_in_response < 10
assert self.max_number > 0
assert self.max_diff > 0
self.max_number_length = len(str(max_number))
# {num1},{num2}:{max_num_in_response},{max_number}
self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length # no negative is allowed
self.np_rng = np.random.default_rng(seed=seed)
def __str__(self):
return f'Prompt length: {self.prompt_length}. Response length: {self.response_length}, ' \
f'Max number: {self.max_number}. Max diff: {self.max_diff}, ' \
f'Max number in response: {self.max_num_in_response}'
def get_state(self):
return {'rng': self.np_rng}
def set_state(self, state):
assert 'rng' in state, 'rng must be inside state'
self.np_rng = state['rng']
@property
def prompt_length(self):
return self._prompt_length
@property
def response_length(self):
# number length + comma length + [EOS]
# The actual number times 1.5 to allow 'U'
return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2
def add(self, a, b):
return (a + b) % self.max_number
def get_all_prompts(self):
all_prompts = []
for first_num in range(self.max_number + 1):
for diff in range(0, self.max_diff + 1):
second_num = self.add(first_num, diff)
for num_to_complete in range(self.max_num_in_response + 1):
prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}'
all_prompts.append(prompt)
return all_prompts
def sample_str_prompts(self):
# step 1: sample initial numbers
first_num = self.np_rng.integers(self.max_number + 1)
diff = self.np_rng.integers(self.max_diff + 1)
second_num = self.add(first_num, diff)
num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)
prompt = str(first_num) + ',' + str(second_num) + f':{self.max_number},{num_to_complete}'
return prompt
def sample_batch_str_prompts(self, batch_size):
str_prompts = []
for _ in range(batch_size):
str_prompts.append(self.sample_str_prompts())
return str_prompts
def compute_attention_mask(prompts, pad_token_id):
    mask = np.ones_like(prompts)
    mask[prompts == pad_token_id] = 0
    return mask


def compute_position_id_with_mask(mask):
    return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None)


def generate_ground_truth_response(prompt: str):
    """Generate ground truth response given a prompt."""
    num, info = prompt.split(':')
    num1, num2 = num.split(',')
    max_number, num_to_gen = info.split(',')
    num1 = int(num1)
    num2 = int(num2)
    max_number = int(max_number)
    num_to_gen = int(num_to_gen)

    diff = (num2 - num1) % max_number

    results = []
    last_num = num2
    for _ in range(num_to_gen):
        curr = (last_num + diff) % max_number
        results.append(str(curr))
        last_num = curr

    response = ','.join(results)
    return response


def compute_reward(prompt: str, response: str, sequence_reward=1.):
    """We compute dense reward here so that we can directly train RL without SFT"""
    response_length = len(response)
    ground_truth_response = generate_ground_truth_response(prompt)
    per_token_reward = sequence_reward / (len(ground_truth_response) + 1)  # including [EOS]

    # pad
    reward = np.zeros(response_length, dtype=np.float32)  # this assumes that each char is a token

    # assign reward until mismatches
    ground_truth_idx = 0
    for i in range(response_length):
        if ground_truth_idx == len(ground_truth_response):
            break

        ground_truth_response_token = ground_truth_response[ground_truth_idx]
        response_token = response[i]

        if ground_truth_response_token == response_token:
            reward[i] = per_token_reward
            ground_truth_idx += 1
        else:
            # no matches
            break

    return reward, {'ground_truth_response': ground_truth_response}


if __name__ == '__main__':
    task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)
    print(task.sample_str_prompts())

    prompt = '7,8:20,0'
    response = ''
    print(compute_reward(prompt, response))

    prompt = '7,8:20,0'
    response = 'E000'
    print(compute_reward(prompt, response))

    prompt = '9,10:20,2'
    response = '11,12,13'
    print(compute_reward(prompt, response))
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py
CharacterTokenzier for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
class CharTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs):
        """Character tokenizer for Hugging Face transformers.

        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by the special unknown
                token 'U'. The special tokens get the following ids:

                "S" (sep): 0
                "E" (eos): 1
                "P" (pad): 2
                "U" (unk): 3

                An id (starting at 4) will be assigned to each character.

            model_max_length (int): Model maximum sequence length.
        """
        eos_token_str = 'E'
        sep_token_str = 'S'
        pad_token_str = 'P'
        unk_token_str = 'U'

        self.characters = characters
        self.model_max_length = model_max_length

        eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False)
        sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False)
        pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False)
        unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False)

        self._vocab_str_to_int = {
            sep_token_str: 0,
            eos_token_str: 1,
            pad_token_str: 2,
            unk_token_str: 3,
            **{
                ch: i + 4 for i, ch in enumerate(characters)
            },
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        super().__init__(
            eos_token=eos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            **kwargs,
        )

        self.chat_template = chat_template
    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def get_vocab(self):
        return self._vocab_str_to_int

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["U"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(self,
                                         token_ids_0: List[int],
                                         token_ids_1: Optional[List[int]] = None) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
            "chat_template": self.chat_template
        }
    @classmethod
    def from_config(cls, config: Dict) -> "CharTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        cfg["chat_template"] = config["chat_template"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)