Commit c132cbcb authored by chenych's avatar chenych
Browse files

0402 update

parent f92481f0
...@@ -12,11 +12,21 @@ ...@@ -12,11 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import re
from setuptools import find_packages, setup from setuptools import find_packages, setup
def get_requires(): def get_version() -> str:
with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"__version__\W*=\W*\"([^\"]+)\""
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f: with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read() file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
...@@ -31,18 +41,19 @@ extra_require = { ...@@ -31,18 +41,19 @@ extra_require = {
def main(): def main():
setup( setup(
name="verl", name="verl",
version="0.2.0.dev0", version=get_version(),
package_dir={"": "."}, description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
packages=find_packages(where="."), long_description=open("README.md", encoding="utf-8").read(),
url="https://github.com/volcengine/verl", long_description_content_type="text/markdown",
license="Apache 2.0",
author="verl", author="verl",
author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn", author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
description="", license="Apache 2.0 License",
url="https://github.com/volcengine/verl",
package_dir={"": "."},
packages=find_packages(where="."),
python_requires=">=3.9.0",
install_requires=get_requires(), install_requires=get_requires(),
extras_require=extra_require, extras_require=extra_require,
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
) )
......
...@@ -12,8 +12,4 @@ ...@@ -12,8 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .protocol import DataProto
__all__ = ["DataProto"]
__version__ = "0.2.0.dev" __version__ = "0.2.0.dev"
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .transformers.flash_attention_utils import flash_attention_forward
from .transformers.qwen2_vl import qwen2_vl_attn_forward
def apply_ulysses_patch(model_type: str) -> None:
    """Install Ulysses sequence-parallel flash attention for the given model type.

    Plain decoder models are patched through the HF attention-function registry,
    while the Qwen2-VL family needs its FlashAttention2 ``forward`` replaced
    directly on the attention classes.
    """
    registry_patched_models = ("llama", "gemma", "gemma2", "mistral", "qwen2")
    vl_patched_models = ("qwen2_vl", "qwen2_5_vl")

    if model_type in registry_patched_models:
        ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
    elif model_type in vl_patched_models:
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2

        Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
        Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
    else:
        raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
from ...utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_world_size,
)
# flash-attn is an optional dependency (absent on CPU-only installs); import
# its kernels and probe their capabilities only when it is available.
# NOTE(review): the scraped source lost indentation — assuming the feature
# flags live inside this guard, mirroring upstream; confirm against the repo.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    # Older flash-attn builds lack these keyword arguments; detect via signature.
    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
    # Opt-in deterministic backward pass, toggled by environment variable.
    _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
    # flash-attn < 2.1 uses a top-left-aligned causal mask instead of bottom-right.
    _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def prepare_fa2_from_position_ids(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
    """Flatten padded 4D q/k/v into the varlen flash-attention layout.

    Sequence boundaries are recovered from the points where ``position_ids``
    restarts at zero, which also covers qwen2-vl mrope position ids.
    Returns the flattened tensors, per-token indices, and the shared
    ``(cu_seqlens, max_seqlen)`` pairs for query and key.
    """
    flat_query = query.view(-1, query.size(-2), query.size(-1))
    flat_key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    flat_value = value.contiguous().view(-1, value.size(-2), value.size(-1))
    flat_positions = position_ids.flatten()

    # One int32 index per token; zeros in the position ids mark sequence starts.
    token_indices = torch.arange(flat_positions.size(0), device=flat_positions.device, dtype=torch.int32)
    seq_starts = token_indices[flat_positions == 0]
    total_tokens = torch.tensor(flat_positions.size(), device=flat_positions.device, dtype=torch.int32)
    cu_seqlens = torch.cat((seq_starts, total_tokens))

    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
    return (flat_query, flat_key, flat_value, token_indices, (cu_seqlens, cu_seqlens), (max_length, max_length))
def _custom_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool = True,
    position_ids: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    **kwargs,
):
    """
    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)

    Under Ulysses sequence parallelism, q/k/v arrive sharded along the sequence
    dimension; they are gathered along sequence (scattered along heads) before
    the kernel runs and the output is scattered back afterwards. Inputs with
    non-monotonic position ids (packed sequences / mrope) are routed to the
    varlen kernel; everything else falls through to the stock HF path.
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if _flash_supports_deterministic:
        flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled

    if kwargs.get("softcap") is not None:
        flash_kwargs["softcap"] = kwargs.pop("softcap")

    # PEFT may hold activations in float32; the FA2 kernel needs fp16/bf16.
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype=torch.bfloat16
    )

    sp_size = get_ulysses_sequence_parallel_world_size()
    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
        # dist.all_gather fills position_ids_lst in place and returns None
        # (async_op defaults to False), so its result must NOT be assigned
        # back to position_ids; the concatenation below produces the value.
        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
        dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
        position_ids = torch.cat(position_ids_lst, dim=-1)  # (..., batch_size, seq_length)

    if position_ids is not None and position_ids.dim() == 3:  # qwen2vl mrope
        # The three mrope channels share sequence boundaries; one channel suffices.
        position_ids = position_ids[0]

    # Non-monotonic position ids indicate packed sequences: use the varlen kernel.
    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=kwargs.pop("dropout", 0.0),
            softmax_scale=kwargs.pop("softmax_scale", None),
            causal=causal,
            **flash_kwargs,
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=is_causal,
            sliding_window=sliding_window,
            use_top_left_mask=use_top_left_mask,
            deterministic=deterministic,
            **kwargs,
        )  # do not pass position_ids to old flash_attention_forward

    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

    return attn_output
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    """HF-registry attention entry point routed through the Ulysses-aware FA2 path."""
    # Read the sequence length before transposing to the FA2 layout.
    seq_len = query.shape[2]

    # FA2 expects (batch, seq, heads, dim) rather than the (batch, heads, seq, dim) input.
    query, key, value = (states.transpose(1, 2) for states in (query, key, value))

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    attn_output = _custom_flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=True,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_flash_use_top_left_mask,
        **kwargs,
    )
    return attn_output, None
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from .flash_attention_utils import flash_attention_forward
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
apply_multimodal_rotary_pos_emb,
repeat_kv,
)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
except ImportError:
pass
def get_rope_index(
    processor: "Qwen2VLProcessor",
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546

    Text tokens advance all three mrope channels (t/h/w) together; vision tokens
    get separate temporal/height/width indices derived from the grid sizes.
    Returns a (3, seq_length) tensor of position ids.
    """
    spatial_merge_size = processor.image_processor.merge_size
    tokens_per_second = 2
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)
        image_index, video_index = 0, 0
        # Drop padded positions; their position ids keep the default value 1.
        input_ids = input_ids[attention_mask == 1]
        image_nums, video_nums = 0, 0
        # The token right after each <|vision_start|> tells whether the block is image or video.
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        # Walk the sequence vision-block by vision-block, alternating text and vision spans.
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1  # sentinel: no image ahead

            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1  # sentinel: no video ahead

            if ed_image < ed_video:
                # Next vision block is an image: images carry no temporal extent.
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                second_per_grid_t = 0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                # Next vision block is a video; scale temporal indices by seconds per grid.
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                if second_per_grid_ts is not None:
                    second_per_grid_t = second_per_grid_ts[video_index]
                else:
                    second_per_grid_t = 1.0

                video_index += 1
                remain_videos -= 1
                ed = ed_video

            # Spatial dims are downscaled by the merge size used by the vision tower.
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )
            # Text span before the vision block: all three channels advance in lockstep.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
            # Vision span: per-channel t/h/w indices flattened over the grid.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision block.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        # Scatter computed ids back into the non-padded slots.
        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        # Text-only path: plain incremental positions replicated across the 3 channels.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            # NOTE(review): uses input_ids.shape[1] although the docstring says 1D
            # input; presumably only reached with 2D inputs — confirm against callers.
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids
def qwen2_vl_attn_forward(
    self: "Qwen2VLAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, None, None]:
    """Replacement for Qwen2-VL FlashAttention2 forward that supports Ulysses SP.

    Mirrors the HF implementation but passes position_ids down to the patched
    flash-attention path so mrope/packed sequences are handled correctly.
    """
    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split hidden size into (heads, head_dim) and move heads before sequence.
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Because the input can be padded, the absolute sequence length depends on the max position id.
    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings

    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    # Expand KV heads to match query heads (GQA).
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # Sliding-window attention only applies past max_window_layers when enabled in config.
    sliding_window = None
    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window

    attn_output, _ = flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        position_ids=position_ids,  # important: pass position ids
    )  # (batch_size, seq_length, num_head / sp_size, head_size)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None, None
...@@ -26,13 +26,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union ...@@ -26,13 +26,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import ray import ray
import torch import torch
import torch.distributed as dist
from numpy.typing import NDArray from numpy.typing import NDArray
from tensordict import TensorDict from tensordict import TensorDict
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from verl.utils.py_functional import union_two_dict from .utils.py_functional import union_two_dict
try: try:
...@@ -89,21 +88,22 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten ...@@ -89,21 +88,22 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
) )
for key, value in tensor_dict2.items(): for key in tensor_dict2.keys():
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], value): if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.") raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = value tensor_dict1[key] = tensor_dict2[key]
return tensor_dict1 return tensor_dict1
def union_numpy_dict( def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
tensor_dict1: Dict[str, Union[List, NDArray]], tensor_dict2: Dict[str, Union[List, NDArray]] for key in tensor_dict2.keys():
) -> Dict[str, Union[List, NDArray]]: if key in tensor_dict1:
for key, value in tensor_dict2.items(): assert isinstance(tensor_dict2[key], np.ndarray)
if key in tensor_dict1 and isinstance(value, np.ndarray) and not np.all(tensor_dict1[key] == value): assert isinstance(tensor_dict1[key], np.ndarray)
raise ValueError(f"Key already exists: {key}.") if not np.all(tensor_dict1[key] == tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = tensor_dict2[key] tensor_dict1[key] = tensor_dict2[key]
...@@ -151,6 +151,7 @@ def collate_fn(data_items: list["DataProtoItem"]): ...@@ -151,6 +151,7 @@ def collate_fn(data_items: list["DataProtoItem"]):
batch = torch.stack(batch).contiguous() batch = torch.stack(batch).contiguous()
non_tensor_batch = batch_collate(non_tensor_batch) non_tensor_batch = batch_collate(non_tensor_batch)
non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()}
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
...@@ -189,7 +190,8 @@ class DataProto: ...@@ -189,7 +190,8 @@ class DataProto:
def __getitem__(self, item): def __getitem__(self, item):
tensor_data = self.batch[item] tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self): def __getstate__(self):
buffer = io.BytesIO() buffer = io.BytesIO()
...@@ -229,8 +231,7 @@ class DataProto: ...@@ -229,8 +231,7 @@ class DataProto:
size_of_numpy_array = 0 size_of_numpy_array = 0
for value in self.non_tensor_batch.values(): for value in self.non_tensor_batch.values():
if isinstance(value, np.ndarray): size_of_numpy_array += value.nbytes
size_of_numpy_array += value.nbytes
size_of_numpy_array /= 1024**3 size_of_numpy_array /= 1024**3
size_of_tensordict /= 1024**3 size_of_tensordict /= 1024**3
...@@ -254,13 +255,13 @@ class DataProto: ...@@ -254,13 +255,13 @@ class DataProto:
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}." assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
@classmethod @classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, list, NDArray]], meta_info=None): def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, NDArray]], meta_info=None):
tensors = {} tensors = {}
non_tensors = {} non_tensors = {}
for key, value in data.items(): for key, value in data.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
tensors[key] = value tensors[key] = value
elif isinstance(value, (list, np.ndarray)): elif isinstance(value, np.ndarray):
non_tensors[key] = value non_tensors[key] = value
else: else:
raise ValueError(f"Unsupported type in data {type(value)}") raise ValueError(f"Unsupported type in data {type(value)}")
...@@ -472,8 +473,6 @@ class DataProto: ...@@ -472,8 +473,6 @@ class DataProto:
assert len(self) % chunks == 0, ( assert len(self) % chunks == 0, (
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
) )
chunk_size = len(self) // chunks
if self.batch is not None: if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0) batch_lst = self.batch.chunk(chunks=chunks, dim=0)
else: else:
...@@ -481,12 +480,8 @@ class DataProto: ...@@ -481,12 +480,8 @@ class DataProto:
non_tensor_batch_lst = [{} for _ in range(chunks)] non_tensor_batch_lst = [{} for _ in range(chunks)]
for key, value in self.non_tensor_batch.items(): for key, value in self.non_tensor_batch.items():
assert isinstance(value, (list, np.ndarray)) assert isinstance(value, np.ndarray)
if isinstance(value, np.ndarray): non_tensor_lst = np.array_split(value, chunks)
non_tensor_lst = np.array_split(value, chunks)
else:
non_tensor_lst = [value[i : i + chunk_size] for i in range(0, len(self), chunk_size)]
assert len(non_tensor_lst) == chunks assert len(non_tensor_lst) == chunks
for i in range(chunks): for i in range(chunks):
non_tensor_batch_lst[i][key] = non_tensor_lst[i] non_tensor_batch_lst[i][key] = non_tensor_lst[i]
...@@ -524,12 +519,7 @@ class DataProto: ...@@ -524,12 +519,7 @@ class DataProto:
non_tensor_batch = batch_collate([d.non_tensor_batch for d in data]) non_tensor_batch = batch_collate([d.non_tensor_batch for d in data])
for key, value in non_tensor_batch.items(): for key, value in non_tensor_batch.items():
if isinstance(value[0], np.ndarray): non_tensor_batch[key] = np.concatenate(value, axis=0)
non_tensor_batch[key] = np.concatenate(value, axis=0)
else:
non_tensor_batch[key] = []
for item in value:
non_tensor_batch[key].extend(item)
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
...@@ -574,16 +564,10 @@ class DataProto: ...@@ -574,16 +564,10 @@ class DataProto:
repeated_non_tensor_batch = {} repeated_non_tensor_batch = {}
for key, value in self.non_tensor_batch.items(): for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray): if interleave:
if interleave: repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
else:
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
else: else:
if interleave: repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
repeated_non_tensor_batch[key] = [item for item in value for _ in range(repeat_times)]
else:
repeated_non_tensor_batch[key] = [item for _ in range(repeat_times) for item in value]
return DataProto( return DataProto(
batch=repeated_batch, batch=repeated_batch,
...@@ -591,39 +575,6 @@ class DataProto: ...@@ -591,39 +575,6 @@ class DataProto:
meta_info=self.meta_info, meta_info=self.meta_info,
) )
def broadcast(self, src: int, group: Optional[ProcessGroup] = None):
    """Broadcast this DataProto in place from rank ``src`` to all ranks in ``group``."""
    # Tensors are broadcast key-by-key in sorted order so every rank issues
    # matching collectives.
    for key in self.batch.sorted_keys:
        dist.broadcast(self.batch[key], src=src, group=group, async_op=False)

    # Non-tensor data holds arbitrary Python objects; use object broadcast.
    object_list = [self.non_tensor_batch]
    dist.broadcast_object_list(object_list, src=src, group=group)
    self.non_tensor_batch = object_list[0]
def all_gather(self, group: Optional[ProcessGroup] = None):
    """All-gather this DataProto in place across ``group``, concatenating along batch dim."""
    world_size = dist.get_world_size(group)
    output = {}
    # Gather each tensor key separately; sorted key order keeps ranks in sync.
    for key in self.batch.sorted_keys:
        value = self.batch[key].contiguous()
        output[key] = [torch.empty_like(value) for _ in range(world_size)]
        dist.all_gather(output[key], value, group=group, async_op=False)
        output[key] = torch.cat(output[key], dim=0)

    self.batch = TensorDict(output, batch_size=self.batch.batch_size[0] * world_size)
    # all gather non_tensor_batch
    all_non_tensor_batch = [None for _ in range(world_size)]
    dist.all_gather_object(all_non_tensor_batch, self.non_tensor_batch, group=group)
    non_tensor_batch = defaultdict(list)
    for key, value in self.non_tensor_batch.items():
        if isinstance(value, np.ndarray):
            non_tensor_batch[key] = np.concatenate([batch[key] for batch in all_non_tensor_batch])
        else:
            # List-valued entries are flattened across ranks in rank order.
            for batch in all_non_tensor_batch:
                non_tensor_batch[key].extend(batch[key])

    self.non_tensor_batch = non_tensor_batch
    self.check_consistency()
@dataclass @dataclass
class DataProtoFuture: class DataProtoFuture:
...@@ -664,10 +615,53 @@ class DataProtoFuture: ...@@ -664,10 +615,53 @@ class DataProtoFuture:
return arg_future_lst return arg_future_lst
def get(self): def get(self):
output = ray.get(self.futures) # dp_size. outputs = ray.get(self.futures) # dp_size.
for o in output: for output in outputs:
assert isinstance(o, DataProto) assert isinstance(output, DataProto)
output = self.collect_fn(output) # select dp, concat
outputs = self.collect_fn(outputs) # select dp, concat
if self.dispatch_fn is not None: if self.dispatch_fn is not None:
output = self.dispatch_fn(output) # split in batch dim, select using dp outputs = self.dispatch_fn(outputs) # split in batch dim, select using dp
return output
return outputs
def allgather_dict_tensors(
    tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
) -> Union[Dict[str, torch.Tensor], TensorDict]:
    """
    All-gather every tensor in a dict (or TensorDict) across ``group``.

    TODO: optimize this.
    - We can use async ops
    - We can use only one allgather
    """
    is_tensor_dict = isinstance(tensors, TensorDict)
    source = tensors.to_dict() if is_tensor_dict else tensors

    gathered = {}
    # Sorted key order guarantees all ranks issue the collectives identically.
    for name in sorted(source.keys()):
        local = source[name]
        buckets = [torch.empty_like(local) for _ in range(size)]
        torch.distributed.all_gather(buckets, local, group=group, async_op=False)
        gathered[name] = torch.cat(buckets, dim=dim)

    if is_tensor_dict:
        return TensorDict(source=gathered, batch_size=tensors.batch_size[0] * size)
    return gathered
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
    """All-gather a DataProto across ``group``, mutating ``data`` in place."""
    # Note that this is an inplace operator just like torch.distributed.all_gather
    original_device = data.batch.device
    data.batch = data.batch.cuda(device=torch.cuda.current_device())
    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
    data.batch = data.batch.to(original_device)

    # all gather non_tensor_batch
    gathered_objects = [None for _ in range(size)]
    torch.distributed.all_gather_object(gathered_objects, data.non_tensor_batch, group=group)
    data.non_tensor_batch = {
        key: np.concatenate([obj[key] for obj in gathered_objects]) for key in data.non_tensor_batch
    }
...@@ -14,3 +14,6 @@ ...@@ -14,3 +14,6 @@
from .worker import Worker from .worker import Worker
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
...@@ -12,14 +12,18 @@ ...@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from enum import Enum from enum import Enum, auto
from functools import wraps from functools import wraps
from types import FunctionType from types import FunctionType
from typing import Dict, List from typing import TYPE_CHECKING, Dict, List, Literal, Union
import ray import ray
from verl.protocol import DataProto, DataProtoFuture from ...protocol import DataProto, DataProtoFuture
if TYPE_CHECKING:
from .worker_group import WorkerGroup
# here we add a magic number of avoid user-defined function already have this attribute # here we add a magic number of avoid user-defined function already have this attribute
...@@ -27,13 +31,13 @@ MAGIC_ATTR = "attrs_3141562937" ...@@ -27,13 +31,13 @@ MAGIC_ATTR = "attrs_3141562937"
class Dispatch(Enum): class Dispatch(Enum):
RANK_ZERO = 0 RANK_ZERO = auto()
ONE_TO_ALL = 1 ONE_TO_ALL = auto()
ALL_TO_ALL = 2 ALL_TO_ALL = auto()
DP_COMPUTE = 3 DP_COMPUTE = auto()
DP_COMPUTE_PROTO = 4 DP_COMPUTE_PROTO = auto()
DP_COMPUTE_PROTO_WITH_FUNC = 5 DP_COMPUTE_PROTO_WITH_FUNC = auto()
DP_COMPUTE_METRIC = 6 DP_COMPUTE_METRIC = auto()
class Execute(Enum): class Execute(Enum):
...@@ -41,100 +45,85 @@ class Execute(Enum): ...@@ -41,100 +45,85 @@ class Execute(Enum):
RANK_ZERO = 1 RANK_ZERO = 1
def _split_args_kwargs_data_proto(chunks, *args, **kwargs): def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
splitted_args = [] splitted_args = []
for arg in args: for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture)) assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks)) splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {} splitted_kwargs = {}
for key, val in kwargs.items(): for key, value in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture)) assert isinstance(value, (DataProto, DataProtoFuture))
splitted_kwargs[key] = val.chunk(chunks=chunks) splitted_kwargs[key] = value.chunk(chunks=chunks)
return splitted_args, splitted_kwargs return splitted_args, splitted_kwargs
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Broadcast every positional and keyword argument to all workers."""
    n = worker_group.world_size
    replicated_args = tuple([value] * n for value in args)
    replicated_kwargs = {name: [value] * n for name, value in kwargs.items()}
    return replicated_args, replicated_kwargs
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Pass the arguments through unchanged (caller already shaped them per worker)."""
    return args, kwargs
def collect_all_to_all(worker_group: "WorkerGroup", output):
    """Return the per-worker outputs unchanged."""
    return output
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
    """Concatenate per-worker outputs; all elements must share a single type.

    DataProto elements concatenate eagerly; ray.ObjectRef elements are wrapped
    into a DataProtoFuture concat. Anything else is unsupported.
    """
    head = outputs[0]
    # make sure all the elements have the same type as the first one
    for item in outputs:
        assert type(item) is type(head)

    if isinstance(head, DataProto):
        return DataProto.concat(outputs)
    if isinstance(head, ray.ObjectRef):
        return DataProtoFuture.concat(outputs)
    raise NotImplementedError
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
    """Validate pre-sharded arguments: each must be a list/tuple with one entry per worker."""
    world_size = worker_group.world_size
    for value in (*args, *kwargs.values()):
        assert isinstance(value, (tuple, list)) and len(value) == world_size

    return args, kwargs
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
    """Sanity-check that there is exactly one output per worker and return them as-is."""
    assert len(outputs) == worker_group.world_size
    return outputs
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
    """Chunk every DataProto(Future) argument into world_size shards, one per worker."""
    return _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
    """Like dispatch_dp_compute_data_proto, but args[0] is a plain function that gets
    broadcast to every worker while the remaining args are chunked."""
    assert type(args[0]) is FunctionType  # NOTE: The first one args is a function!
    func = args[0]

    sharded_args, sharded_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
    return [[func] * worker_group.world_size] + sharded_args, sharded_kwargs
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
    """Validate per-worker outputs, then concatenate them into one DataProto (or future)."""
    for item in outputs:
        assert isinstance(item, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(item)}"

    checked = collect_dp_compute(worker_group, outputs)
    return _concat_data_proto_or_future(checked)
def get_predefined_dispatch_fn(dispatch_mode): def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
predefined_dispatch_mode_fn = { predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: { Dispatch.ONE_TO_ALL: {
"dispatch_fn": dispatch_one_to_all, "dispatch_fn": dispatch_one_to_all,
...@@ -144,6 +133,10 @@ def get_predefined_dispatch_fn(dispatch_mode): ...@@ -144,6 +133,10 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_all_to_all, "dispatch_fn": dispatch_all_to_all,
"collect_fn": collect_all_to_all, "collect_fn": collect_all_to_all,
}, },
Dispatch.DP_COMPUTE: {
"dispatch_fn": dispatch_dp_compute,
"collect_fn": collect_dp_compute,
},
Dispatch.DP_COMPUTE_PROTO: { Dispatch.DP_COMPUTE_PROTO: {
"dispatch_fn": dispatch_dp_compute_data_proto, "dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute_data_proto, "collect_fn": collect_dp_compute_data_proto,
...@@ -152,11 +145,15 @@ def get_predefined_dispatch_fn(dispatch_mode): ...@@ -152,11 +145,15 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_dp_compute_data_proto_with_func, "dispatch_fn": dispatch_dp_compute_data_proto_with_func,
"collect_fn": collect_dp_compute_data_proto, "collect_fn": collect_dp_compute_data_proto,
}, },
Dispatch.DP_COMPUTE_METRIC: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute,
},
} }
return predefined_dispatch_mode_fn[dispatch_mode] return predefined_dispatch_mode_fn[dispatch_mode]
def get_predefined_execute_fn(execute_mode): def get_predefined_execute_fn(execute_mode: Execute):
""" """
Note that here we only asks execute_all and execute_rank_zero to be implemented Note that here we only asks execute_all and execute_rank_zero to be implemented
Leave the choice of how these two functions handle argument 'blocking' to users Leave the choice of how these two functions handle argument 'blocking' to users
...@@ -168,17 +165,17 @@ def get_predefined_execute_fn(execute_mode): ...@@ -168,17 +165,17 @@ def get_predefined_execute_fn(execute_mode):
return predefined_execute_mode_fn[execute_mode] return predefined_execute_mode_fn[execute_mode]
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
    """Validate that dispatch_mode is either a Dispatch member or a dict providing
    custom 'dispatch_fn' and 'collect_fn' callables."""
    assert isinstance(dispatch_mode, (Dispatch, dict)), (
        f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
    )
    if isinstance(dispatch_mode, dict):
        for required in ("dispatch_fn", "collect_fn"):
            assert required in dispatch_mode, f"key {required} should be in dispatch_mode if it is a dictionary"
def _check_execute_mode(execute_mode: Execute):
    """Validate that execute_mode is a member of the Execute enum."""
    assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
...@@ -189,9 +186,10 @@ def _materialize_futures(*args, **kwargs): ...@@ -189,9 +186,10 @@ def _materialize_futures(*args, **kwargs):
arg = arg.get() arg = arg.get()
# add more type to materialize # add more type to materialize
new_args.append(arg) new_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, DataProtoFuture): for key, value in kwargs.items():
kwargs[k] = v.get() if isinstance(value, DataProtoFuture):
kwargs[key] = value.get()
new_args = tuple(new_args) new_args = tuple(new_args)
return new_args, kwargs return new_args, kwargs
......
...@@ -18,11 +18,13 @@ the class for Worker ...@@ -18,11 +18,13 @@ the class for Worker
import os import os
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
import ray import ray
import torch
from verl.single_controller.base.decorator import Dispatch, Execute, register from .decorator import Dispatch, Execute, register
from verl.single_controller.base.register_center.ray import create_worker_group_register_center from .register_center.ray import create_worker_group_register_center
@dataclass @dataclass
...@@ -40,7 +42,7 @@ class DistGlobalInfo: ...@@ -40,7 +42,7 @@ class DistGlobalInfo:
class WorkerHelper: class WorkerHelper:
def _get_node_ip(self): def _get_node_ip(self) -> str:
host_ipv4 = os.getenv("MY_HOST_IP", None) host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None) host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6 host_ip_by_env = host_ipv4 or host_ipv6
...@@ -49,12 +51,12 @@ class WorkerHelper: ...@@ -49,12 +51,12 @@ class WorkerHelper:
host_ip = host_ip_by_env or host_ip_by_sdk host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip return host_ip
def _get_free_port(self): def _get_free_port(self) -> int:
with socket.socket() as sock: with socket.socket() as sock:
sock.bind(("", 0)) sock.bind(("", 0))
return sock.getsockname()[1] return sock.getsockname()[1]
    def get_availale_master_addr_port(self) -> Tuple[str, str]:
        """Return this node's IP and a free port, both as strings.

        NOTE(review): "availale" is a typo for "available", kept because external
        callers reference this name.
        """
        return self._get_node_ip(), str(self._get_free_port())
def _get_pid(self): def _get_pid(self):
...@@ -81,16 +83,26 @@ class WorkerMeta: ...@@ -81,16 +83,26 @@ class WorkerMeta:
# we assume that in each WorkerGroup, there is a Master Worker # we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper): class Worker(WorkerHelper):
"""A (distributed) worker."""
_world_size: int
_rank: int
_local_world_size: int
_local_rank: int
_master_addr: str
_master_port: str
_cuda_visible_devices: str
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
instance = super().__new__(cls) instance = super().__new__(cls)
# note that here we use int to distinguish # note that here we use int to distinguish
disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
if disable_worker_init: if disable_worker_init:
return instance return instance
rank = os.environ.get("RANK", None) rank = os.getenv("RANK", None)
worker_group_prefix = os.environ.get("WG_PREFIX", None) worker_group_prefix = os.getenv("WG_PREFIX", None)
# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
...@@ -112,13 +124,19 @@ class Worker(WorkerHelper): ...@@ -112,13 +124,19 @@ class Worker(WorkerHelper):
def __init__(self, cuda_visible_devices=None) -> None: def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely # construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
world_size = int(os.environ["WORLD_SIZE"]) world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.environ["RANK"]) rank = int(os.getenv("RANK"))
self._rank = rank self._rank = rank
self._world_size = world_size self._world_size = world_size
master_addr = os.environ["MASTER_ADDR"] if "AMD" in torch.cuda.get_device_name():
master_port = os.environ["MASTER_PORT"] os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
torch.cuda.set_device(int(cuda_visible_devices))
master_addr = os.getenv("MASTER_ADDR")
master_port = os.getenv("MASTER_PORT")
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0"))
...@@ -149,6 +167,7 @@ class Worker(WorkerHelper): ...@@ -149,6 +167,7 @@ class Worker(WorkerHelper):
if val is not None: if val is not None:
# print(f"set {key} to {val}") # print(f"set {key} to {val}")
os.environ[key] = str(val) os.environ[key] = str(val)
os.environ["REDIS_STORE_SERVER_HOST"] = ( os.environ["REDIS_STORE_SERVER_HOST"] = (
str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
) )
...@@ -157,7 +176,7 @@ class Worker(WorkerHelper): ...@@ -157,7 +176,7 @@ class Worker(WorkerHelper):
return self._master_addr, self._master_port return self._master_addr, self._master_port
def get_cuda_visible_devices(self): def get_cuda_visible_devices(self):
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set") cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices return cuda_visible_devices
def print_rank0(self, *args, **kwargs): def print_rank0(self, *args, **kwargs):
......
...@@ -19,20 +19,20 @@ import logging ...@@ -19,20 +19,20 @@ import logging
import signal import signal
import threading import threading
import time import time
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List, Optional
from verl.single_controller.base.decorator import ( from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
MAGIC_ATTR,
Dispatch,
get_predefined_dispatch_fn,
get_predefined_execute_fn,
)
class ResourcePool: class ResourcePool:
def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None: """The resource pool with meta info such as world size."""
    def __init__(
        self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
    ) -> None:
        """Initialize the pool bookkeeping.

        Args:
            process_on_nodes: per-node process counts; defaults to an empty list
                (None is used as sentinel to avoid a mutable default argument).
            max_collocate_count: cap on collocated processes — presumably per GPU; TODO confirm at use site.
            n_gpus_per_node: GPUs available on each node.
        """
        if process_on_nodes is None:
            process_on_nodes = []
        self._store = process_on_nodes
        self.max_collocate_count = max_collocate_count
        self.n_gpus_per_node = n_gpus_per_node  # this is left for future huawei GPU that contains 16 GPUs per node
...@@ -73,28 +73,23 @@ class ClassWithInitArgs: ...@@ -73,28 +73,23 @@ class ClassWithInitArgs:
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
# def add_arg(self, arg):
# self.args += (arg,)
# def add_kwarg(self, key, value):
# self.kwargs[key] = value
    def __call__(self) -> Any:
        """Instantiate the stored class with the stored constructor args/kwargs."""
        return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    """Poll `workers` forever; when one is found dead, raise SIGABRT so the main
    thread is notified (per the log message). Never returns — intended to run in
    a watchdog thread.

    Args:
        workers: worker handles to monitor.
        is_alive: predicate mapping a worker handle to liveness.
        gap_time: seconds to sleep between polling rounds.
    """
    while True:
        for worker in workers:
            if not is_alive(worker):
                logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
                signal.raise_signal(signal.SIGABRT)

        time.sleep(gap_time)
class WorkerGroup: class WorkerGroup:
"""A group of workers"""
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False self._is_init_with_detached_workers = True if resource_pool is None else False
...@@ -136,14 +131,10 @@ class WorkerGroup: ...@@ -136,14 +131,10 @@ class WorkerGroup:
    def world_size(self):
        # Number of workers currently managed by this group.
        return len(self._workers)
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup,
# MegatronWorkerGroup, XperfWorkerGroup should skip
def _bind_worker_method(self, user_defined_cls, func_generator): def _bind_worker_method(self, user_defined_cls, func_generator):
""" """
Bind the worker method to the WorkerGroup Bind the worker method to the WorkerGroup
""" """
for method_name in dir(user_defined_cls): for method_name in dir(user_defined_cls):
try: try:
method = getattr(user_defined_cls, method_name) method = getattr(user_defined_cls, method_name)
......
...@@ -13,27 +13,28 @@ ...@@ -13,27 +13,28 @@
# limitations under the License. # limitations under the License.
import os import os
import random
import re
import string
import time import time
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch from unittest.mock import patch
import ray import ray
from ray.actor import ActorHandle
from ray.experimental.state.api import get_actor from ray.experimental.state.api import get_actor
from ray.util import list_named_actors from ray.util import list_named_actors
from ray.util.placement_group import PlacementGroup, placement_group from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
from verl.single_controller.base.decorator import MAGIC_ATTR from ..base.decorator import MAGIC_ATTR
__all__ = ["Worker"] __all__ = ["Worker"]
def get_random_string(length: int) -> str:
    """Return a random string of `length` characters drawn from [a-zA-Z0-9]."""
    alphabet = string.ascii_letters + string.digits
    chosen = [random.choice(alphabet) for _ in range(length)]
    return "".join(chosen)
...@@ -50,6 +51,27 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block ...@@ -50,6 +51,27 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block
return func return func
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
    """
    Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.

    FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK
    to be consistent across nodes when resume from checkpoint.

    With this function, if there's only one resource pool and there's no node change, RANK should be consistent
    across nodes in multiple ray jobs, even if the whole ray cluster is restarted.
    """
    # Map every node ID in the cluster to its IP address.
    node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()}
    pg_ip = {}
    for pg in pgs:
        # NOTE(review): ray._private is not a stable public API — may break across ray upgrades.
        specs = ray._private.state.state.placement_group_table(pg.id)
        # all bundles should be on the same node, so the first bundle's node identifies the whole pg
        node_id = specs["bundles_to_node_id"][0]
        pg_ip[pg.id] = node_ip[node_id]
    return sorted(pgs, key=lambda pg: pg_ip[pg.id])
class RayResourcePool(ResourcePool): class RayResourcePool(ResourcePool):
def __init__( def __init__(
self, self,
...@@ -57,7 +79,7 @@ class RayResourcePool(ResourcePool): ...@@ -57,7 +79,7 @@ class RayResourcePool(ResourcePool):
use_gpu: bool = True, use_gpu: bool = True,
name_prefix: str = "", name_prefix: str = "",
max_colocate_count: int = 5, max_colocate_count: int = 5,
detached=False, detached: bool = False,
) -> None: ) -> None:
super().__init__(process_on_nodes, max_colocate_count) super().__init__(process_on_nodes, max_colocate_count)
self.use_gpu = use_gpu self.use_gpu = use_gpu
...@@ -66,7 +88,7 @@ class RayResourcePool(ResourcePool): ...@@ -66,7 +88,7 @@ class RayResourcePool(ResourcePool):
self.pgs = None self.pgs = None
self.detached = detached self.detached = detached
def get_placement_groups(self, strategy="STRICT_PACK", name=None): def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]:
if self.pgs is not None: if self.pgs is not None:
return self.pgs return self.pgs
...@@ -97,7 +119,7 @@ class RayResourcePool(ResourcePool): ...@@ -97,7 +119,7 @@ class RayResourcePool(ResourcePool):
def extract_pg_from_exist( def extract_pg_from_exist(
resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List: ) -> List[PlacementGroup]:
src_pgs = [ src_pgs = [
pg pg
for role_name, resource_pool in resource_pools.items() for role_name, resource_pool in resource_pools.items()
...@@ -151,7 +173,12 @@ class RayClassWithInitArgs(ClassWithInitArgs): ...@@ -151,7 +173,12 @@ class RayClassWithInitArgs(ClassWithInitArgs):
self._options.update(options) self._options.update(options)
def __call__( def __call__(
self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None self,
placement_group: PlacementGroup,
placement_group_bundle_idx: int,
use_gpu: bool = True,
num_gpus: int = 1,
sharing_with: Worker = None,
) -> Any: ) -> Any:
if sharing_with is not None: if sharing_with is not None:
target_node_id = ray.get(sharing_with.get_node_id.remote()) target_node_id = ray.get(sharing_with.get_node_id.remote())
...@@ -188,8 +215,8 @@ class RayWorkerGroup(WorkerGroup): ...@@ -188,8 +215,8 @@ class RayWorkerGroup(WorkerGroup):
ray_cls_with_init: RayClassWithInitArgs = None, ray_cls_with_init: RayClassWithInitArgs = None,
bin_pack: bool = True, bin_pack: bool = True,
name_prefix: str = None, name_prefix: str = None,
detached=False, detached: bool = False,
worker_names=None, worker_names: List[str] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(resource_pool=resource_pool, **kwargs) super().__init__(resource_pool=resource_pool, **kwargs)
...@@ -210,21 +237,24 @@ class RayWorkerGroup(WorkerGroup): ...@@ -210,21 +237,24 @@ class RayWorkerGroup(WorkerGroup):
if ray_cls_with_init is not None: if ray_cls_with_init is not None:
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
def _is_worker_alive(self, worker: ray.actor.ActorHandle): def _is_worker_alive(self, worker: ActorHandle) -> bool:
worker_state_dict = get_actor(worker._actor_id.hex()) worker_state_dict = get_actor(worker._actor_id.hex())
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
def _init_with_detached_workers(self, worker_names): def _init_with_detached_workers(self, worker_names: List[str]) -> None:
workers = [ray.get_actor(name=name) for name in worker_names] workers = [ray.get_actor(name=name) for name in worker_names]
self._workers = workers self._workers = workers
self._world_size = len(worker_names) self._world_size = len(worker_names)
def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): def _init_with_resource_pool(
self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool
):
use_gpu = resource_pool.use_gpu use_gpu = resource_pool.use_gpu
strategy = "PACK" strategy = "PACK"
if bin_pack: if bin_pack:
strategy = "STRICT_PACK" strategy = "STRICT_PACK"
pgs = resource_pool.get_placement_groups(strategy=strategy) pgs = resource_pool.get_placement_groups(strategy=strategy)
world_size = resource_pool.world_size world_size = resource_pool.world_size
self._world_size = world_size self._world_size = world_size
...@@ -232,8 +262,8 @@ class RayWorkerGroup(WorkerGroup): ...@@ -232,8 +262,8 @@ class RayWorkerGroup(WorkerGroup):
num_gpus = 1 / resource_pool.max_collocate_count num_gpus = 1 / resource_pool.max_collocate_count
rank = -1 rank = -1
for pg_idx, local_world_size in enumerate(resource_pool.store): local_world_size = resource_pool.store[0]
pg = pgs[pg_idx] for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the "
for local_rank in range(local_world_size): for local_rank in range(local_world_size):
rank += 1 rank += 1
...@@ -251,8 +281,6 @@ class RayWorkerGroup(WorkerGroup): ...@@ -251,8 +281,6 @@ class RayWorkerGroup(WorkerGroup):
env_vars["MASTER_ADDR"] = self._master_addr env_vars["MASTER_ADDR"] = self._master_addr
env_vars["MASTER_PORT"] = self._master_port env_vars["MASTER_PORT"] = self._master_port
import re
cia_name = type(ray_cls_with_init.cls).__name__ cia_name = type(ray_cls_with_init.cls).__name__
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
...@@ -272,7 +300,7 @@ class RayWorkerGroup(WorkerGroup): ...@@ -272,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
if rank == 0: if rank == 0:
register_center_actor = None register_center_actor = None
for _ in range(120): for _ in range(360):
if f"{self.name_prefix}_register_center" not in list_named_actors(): if f"{self.name_prefix}_register_center" not in list_named_actors():
time.sleep(1) time.sleep(1)
else: else:
......
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from verl.workers.config import WorkerConfig from ..workers.config import WorkerConfig
def recursive_post_init(dataclass_obj): def recursive_post_init(dataclass_obj):
...@@ -36,12 +36,13 @@ class DataConfig: ...@@ -36,12 +36,13 @@ class DataConfig:
train_files: str = "" train_files: str = ""
val_files: str = "" val_files: str = ""
prompt_key: str = "prompt" prompt_key: str = "prompt"
answer_key: str = "answer"
image_key: str = "images"
max_prompt_length: int = 512 max_prompt_length: int = 512
max_response_length: int = 512 max_response_length: int = 512
rollout_batch_size: int = 512 rollout_batch_size: int = 512
return_raw_input_ids: bool = False val_batch_size: int = -1
return_raw_prompt: bool = False system_prompt: Optional[str] = None
system_prompt: str = r"Please reason step by step, and put your final answer within \boxed{}."
shuffle: bool = True shuffle: bool = True
seed: int = 1 seed: int = 1
max_pixels: int = 4194304 max_pixels: int = 4194304
...@@ -52,10 +53,12 @@ class DataConfig: ...@@ -52,10 +53,12 @@ class DataConfig:
class AlgorithmConfig:
    """Algorithm-level hyper-parameters shared with the worker configs via PPOConfig.post_init."""

    gamma: float = 1.0  # presumably the discount factor — confirm at use site
    lam: float = 1.0  # presumably the GAE lambda — confirm at use site
    adv_estimator: str = "grpo"
    disable_kl: bool = False  # copied into worker.actor by PPOConfig.post_init
    use_kl_loss: bool = False  # copied into worker.actor by PPOConfig.post_init
    kl_penalty: str = "kl"
    kl_type: str = "fixed"  # "fixed" or "adaptive"; selects the controller in get_kl_controller
    kl_coef: float = 1e-3  # initial KL coefficient handed to the controller
    kl_horizon: float = 0.0  # must be > 0 when kl_type == "adaptive"
    kl_target: float = 0.0  # target KL for the adaptive controller
...@@ -67,18 +70,17 @@ class TrainerConfig: ...@@ -67,18 +70,17 @@ class TrainerConfig:
project_name: str = "easy_r1" project_name: str = "easy_r1"
experiment_name: str = "demo" experiment_name: str = "demo"
logger: Tuple[str] = ("console", "wandb") logger: Tuple[str] = ("console", "wandb")
val_generations_to_log_to_wandb: int = 0
nnodes: int = 1 nnodes: int = 1
n_gpus_per_node: int = 8 n_gpus_per_node: int = 8
save_freq: int = -1 critic_warmup: int = 0
load_checkpoint_path: Optional[str] = None val_freq: int = -1
val_before_train: bool = True val_before_train: bool = True
val_only: bool = False val_only: bool = False
test_freq: int = -1 val_generations_to_log: int = 0
critic_warmup: int = 0 save_freq: int = -1
remove_previous_ckpt: bool = False save_limit: int = -1
del_local_ckpt_after_load: bool = False
save_checkpoint_path: Optional[str] = None save_checkpoint_path: Optional[str] = None
load_checkpoint_path: Optional[str] = None
def post_init(self): def post_init(self):
if self.save_checkpoint_path is None: if self.save_checkpoint_path is None:
...@@ -95,6 +97,10 @@ class PPOConfig: ...@@ -95,6 +97,10 @@ class PPOConfig:
    def post_init(self):
        """Propagate shared settings from the data/algorithm sections into the worker configs."""
        # keep rollout lengths in sync with the dataset truncation settings
        self.worker.rollout.prompt_length = self.data.max_prompt_length
        self.worker.rollout.response_length = self.data.max_response_length
        # the actor reads its KL-related settings from the algorithm section
        self.worker.actor.disable_kl = self.algorithm.disable_kl
        self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
        self.worker.actor.kl_penalty = self.algorithm.kl_penalty
        self.worker.actor.kl_coef = self.algorithm.kl_coef
    def deep_post_init(self):
        # Run post_init recursively over nested config dataclasses (see recursive_post_init).
        recursive_post_init(self)
......
# Copyright 2022 The HuggingFace Team
# Copyright 2024 Bytedance Ltd. and/or its affiliates # Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,50 +18,55 @@ The function implemented in this file should be used by trainer with different d ...@@ -18,50 +18,55 @@ The function implemented in this file should be used by trainer with different d
implement PPO implement PPO
""" """
from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import verl.utils.torch_functional as verl_F from ..utils import torch_functional as VF
if TYPE_CHECKING: if TYPE_CHECKING:
from verl.trainer.config import AlgorithmConfig from .config import AlgorithmConfig
class KLController(ABC):
    """Interface for KL-coefficient schedulers.

    Implementations expose the current coefficient as `.value` and adjust it
    via `update` after each measurement of the current KL.
    """

    @abstractmethod
    def update(self, current_kl: float, n_steps: int) -> None: ...
class AdaptiveKLController(KLController):
    """Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf"""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Proportional error, clipped to +/-20% so a single update stays bounded.
        proportional_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon
class FixedKLController: class FixedKLController(KLController):
"""Fixed KL controller.""" """Fixed KL controller."""
def __init__(self, kl_coef: float): def __init__(self, init_kl_coef: float):
self.value = kl_coef self.value = init_kl_coef
def update(self, current_kl, n_steps): def update(self, current_kl: float, n_steps: int) -> None:
pass pass
def get_kl_controller(algorithm_config: "AlgorithmConfig"): def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
"""Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
if algorithm_config.kl_type == "fixed": if algorithm_config.kl_type == "fixed":
kl_ctrl = FixedKLController(kl_coef=algorithm_config.kl_coef) kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
elif algorithm_config.kl_type == "adaptive": elif algorithm_config.kl_type == "adaptive":
assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}." assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
kl_ctrl = AdaptiveKLController( kl_ctrl = AdaptiveKLController(
...@@ -70,19 +75,20 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig"): ...@@ -70,19 +75,20 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig"):
horizon=algorithm_config.kl_horizon, horizon=algorithm_config.kl_horizon,
) )
else: else:
raise ValueError("Unknown kl_ctrl type") raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")
return kl_ctrl return kl_ctrl
@torch.no_grad()
def compute_gae_advantage_return( def compute_gae_advantage_return(
token_level_rewards: torch.Tensor, token_level_rewards: torch.Tensor,
values: torch.Tensor, values: torch.Tensor,
eos_mask: torch.Tensor, eos_mask: torch.Tensor,
gamma: torch.Tensor, gamma: torch.Tensor,
lam: torch.Tensor, lam: torch.Tensor,
): ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py """Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513
Args: Args:
token_level_rewards: `(torch.Tensor)` token_level_rewards: `(torch.Tensor)`
...@@ -103,27 +109,26 @@ def compute_gae_advantage_return( ...@@ -103,27 +109,26 @@ def compute_gae_advantage_return(
shape: (bs, response_length) shape: (bs, response_length)
""" """
with torch.no_grad(): lastgaelam = 0
lastgaelam = 0 advantages_reversed = []
advantages_reversed = [] gen_len = token_level_rewards.shape[-1]
gen_len = token_level_rewards.shape[-1] for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
for t in reversed(range(gen_len)): delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 lastgaelam = delta + gamma * lam * lastgaelam
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] advantages_reversed.append(lastgaelam)
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam) advantages = torch.stack(advantages_reversed[::-1], dim=1)
advantages = torch.stack(advantages_reversed[::-1], dim=1) returns = (advantages + values) * eos_mask
advantages = VF.masked_whiten(advantages, eos_mask) * eos_mask
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
return advantages, returns return advantages, returns
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage( def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6 token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Compute advantage for GRPO, operating only on Outcome reward Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response). (with only one scalar reward for each response).
...@@ -133,6 +138,50 @@ def compute_grpo_outcome_advantage( ...@@ -133,6 +138,50 @@ def compute_grpo_outcome_advantage(
eos_mask: `(torch.Tensor)` eos_mask: `(torch.Tensor)`
shape: (bs, response_length) shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean, id2std = {}, {}
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
@torch.no_grad()
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns: Returns:
advantages: `(torch.Tensor)` advantages: `(torch.Tensor)`
shape: (bs, response_length) shape: (bs, response_length)
...@@ -144,31 +193,33 @@ def compute_grpo_outcome_advantage( ...@@ -144,31 +193,33 @@ def compute_grpo_outcome_advantage(
id2score = defaultdict(list) id2score = defaultdict(list)
id2mean = {} id2mean = {}
id2std = {} bsz = scores.shape[0]
for i in range(bsz):
with torch.no_grad(): id2score[index[i]].append(scores[i])
bsz = scores.shape[0]
for i in range(bsz): for idx in id2score:
id2score[index[i]].append(scores[i]) if len(id2score[idx]) == 1:
for idx in id2score: id2mean[idx] = torch.tensor(0.0)
if len(id2score[idx]) == 1: elif len(id2score[idx]) > 1:
id2mean[idx] = torch.tensor(0.0) id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.tensor(1.0) else:
elif len(id2score[idx]) > 1: raise ValueError(f"no score in prompt index: {idx}.")
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]])) for i in range(bsz):
else: response_num = len(id2score[index[i]])
raise ValueError(f"no score in prompt index: {idx}") if response_num > 1:
for i in range(bsz): scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) response_num - 1
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask )
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores return scores, scores
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage( def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Compute advantage for REINFORCE++. Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262 This implementation is based on the paper: https://arxiv.org/abs/2501.03262
...@@ -184,26 +235,24 @@ def compute_reinforce_plus_plus_outcome_advantage( ...@@ -184,26 +235,24 @@ def compute_reinforce_plus_plus_outcome_advantage(
Returns: `(torch.Tensor)` Returns: `(torch.Tensor)`
shape: (bs, response_length) shape: (bs, response_length)
""" """
returns = torch.zeros_like(token_level_rewards)
with torch.no_grad(): running_return = 0
returns = torch.zeros_like(token_level_rewards) for t in reversed(range(token_level_rewards.shape[1])):
running_return = 0 running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
for t in reversed(range(token_level_rewards.shape[1])): # Reset after EOS
running_return = token_level_rewards[:, t] + gamma * running_return running_return = running_return * eos_mask[:, t]
returns[:, t] = running_return
# Reset after EOS advantages = VF.masked_whiten(returns, eos_mask)
running_return = running_return * eos_mask[:, t] advantages *= eos_mask
returns *= eos_mask
advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
return advantages, returns return advantages, returns
@torch.no_grad()
def compute_remax_outcome_advantage( def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Compute advantage for ReMax, operating only on Outcome reward Compute advantage for ReMax, operating only on Outcome reward
This implementation is based on the paper: https://arxiv.org/abs/2310.10505 This implementation is based on the paper: https://arxiv.org/abs/2310.10505
...@@ -225,23 +274,31 @@ def compute_remax_outcome_advantage( ...@@ -225,23 +274,31 @@ def compute_remax_outcome_advantage(
""" """
response_length = token_level_rewards.shape[-1] response_length = token_level_rewards.shape[-1]
# scores = token_level_rewards.sum(dim=-1) # scores = token_level_rewards.sum(dim=-1)
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * eos_mask
with torch.no_grad(): advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
return advantages, returns return advantages, returns
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): def compute_rewards(
kl = old_log_prob - ref_log_prob token_level_scores: torch.Tensor,
log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
kl_ratio: float,
) -> torch.Tensor:
kl = log_probs - ref_log_probs
return token_level_scores - kl * kl_ratio return token_level_scores - kl * kl_ratio
def compute_policy_loss( def compute_policy_loss(
old_log_prob, log_prob, advantages, eos_mask, cliprange old_log_probs: torch.Tensor,
log_probs: torch.Tensor,
advantages: torch.Tensor,
eos_mask: torch.Tensor,
cliprange: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 """Compute the policy loss.
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
Args: Args:
old_log_prob: `(torch.Tensor)` old_log_prob: `(torch.Tensor)`
...@@ -260,95 +317,88 @@ def compute_policy_loss( ...@@ -260,95 +317,88 @@ def compute_policy_loss(
policy gradient loss computed via PPO policy gradient loss computed via PPO
pg_clipfrac: (float) pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped a float number indicating the fraction of policy gradient loss being clipped
""" """
negative_approx_kl = log_prob - old_log_prob negative_approx_kl = log_probs - old_log_probs
# clamp the ratio before exp to avoid nan
# see: https://github.com/pytorch/pytorch/issues/10729
ratio = torch.exp(negative_approx_kl) ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask) clipped_ratio = torch.exp(torch.clamp(negative_approx_kl, np.log(1.0 - cliprange), np.log(1.0 + cliprange)))
ppo_kl = VF.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) pg_losses2 = -advantages * clipped_ratio
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask) pg_loss = VF.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask) pg_clipfrac = VF.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl return pg_loss, pg_clipfrac, ppo_kl
def compute_entropy_loss(logits, eos_mask): def compute_value_loss(
"""Compute Categorical entropy loss vpreds: torch.Tensor,
returns: torch.Tensor,
Args: values: torch.Tensor,
logits: `(torch.Tensor)` eos_mask: torch.Tensor,
shape: (bs, response_length, vocab_size) cliprange_value: float,
eos_mask: `(torch.Tensor)` ) -> Tuple[torch.Tensor, float]:
shape: (bs, response_length) """Compute the value loss.
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value): Copied from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args: Args:
vpreds (`torch.FloatTensor`): vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`) Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`): returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`) Ground truth returns, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange_value: (float)
The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347
Returns: Returns:
vf_loss: a scalar (`torch.FloatTensor`): vf_loss: a scalar (`torch.FloatTensor`):
value function loss value function loss
vf_clipfrac: a float vf_clipfrac: a float
The ratio of vf being clipped The ratio of vf being clipped
""" """
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2 vf_losses1 = torch.square(vpreds - returns)
vf_losses2 = (vpredclipped - returns) ** 2 vf_losses2 = torch.square(vpredclipped - returns)
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask) vf_loss = 0.5 * VF.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask) vf_clipfrac = VF.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac return vf_loss, vf_clipfrac
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.Tensor: def kl_penalty(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
"""Compute KL divergence given logprob and ref_logprob. """Compute KL divergence given log_probs and ref_log_probs.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Args: Args:
logprob: log_probs: torch.Tensor
ref_logprob: ref_log_probs: torch.Tensor
Returns: Returns:
kl_div: torch.Tensor
""" """
log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
if kl_penalty == "kl": if kl_penalty == "kl":
return logprob - ref_logprob return log_probs - ref_log_probs
if kl_penalty == "abs": if kl_penalty == "abs":
return (logprob - ref_logprob).abs() return (log_probs - ref_log_probs).abs()
if kl_penalty == "mse": if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square() return 0.5 * (log_probs - ref_log_probs).square()
# J. Schulman. Approximating kl divergence, 2020. # J. Schulman. Approximating kl divergence, 2020.
# # URL http://joschu.net/blog/kl-approx.html. # URL http://joschu.net/blog/kl-approx.html
if kl_penalty == "low_var_kl": if kl_penalty == "low_var_kl":
kl = ref_logprob - logprob kl = ref_log_probs - log_probs
ratio = torch.exp(kl) kld = (kl.exp() - kl - 1).contiguous()
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10) return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full": if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)
raise NotImplementedError
raise NotImplementedError raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
...@@ -16,81 +16,100 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o ...@@ -16,81 +16,100 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o
""" """
import json import json
import torch
import ray import ray
from omegaconf import OmegaConf from omegaconf import OmegaConf
from verl.single_controller.ray import RayWorkerGroup from ..single_controller.ray import RayWorkerGroup
from verl.trainer.config import PPOConfig from ..utils.tokenizer import get_processor, get_tokenizer
from verl.trainer.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role from ..workers.fsdp_workers import FSDPWorker
from verl.utils import get_processor, get_tokenizer from ..workers.reward import CustomRewardManager
from verl.workers.fsdp_workers import FSDPWorker from .config import PPOConfig
from verl.workers.reward import CustomRewardManager from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
# please make sure main_task is not scheduled on head
@ray.remote(num_cpus=1)
class Runner:
    """A runner for RL training.

    Runs as a detached Ray actor (1 CPU) so the driver process stays light;
    `run` performs the whole PPO training session: config finalization,
    tokenizer/processor loading, worker/resource-pool wiring, and the
    trainer fit loop.
    """

    def run(self, config: PPOConfig) -> None:
        """Execute one full training run described by `config`.

        Args:
            config: fully-populated PPO configuration; `deep_post_init` is
                called here to resolve derived fields before use.
        """
        # print config
        config.deep_post_init()
        print(json.dumps(config.to_dict(), indent=2))

        # instantiate tokenizer
        # NOTE(review): tokenizer and processor are loaded from the same
        # actor model path with identical trust settings; presumably the
        # processor is only relevant for multi-modal models — confirm.
        tokenizer = get_tokenizer(
            config.worker.actor.model.model_path,
            trust_remote_code=config.worker.actor.model.trust_remote_code,
            use_fast=True,
        )
        processor = get_processor(
            config.worker.actor.model.model_path,
            trust_remote_code=config.worker.actor.model.trust_remote_code,
            use_fast=True,
        )

        # define worker classes
        # All three roles are backed by the same FSDPWorker class; the role
        # determines which sub-modules each worker instance activates.
        ray_worker_group_cls = RayWorkerGroup
        role_worker_mapping = {
            Role.ActorRollout: ray.remote(FSDPWorker),
            Role.Critic: ray.remote(FSDPWorker),
            Role.RefPolicy: ray.remote(FSDPWorker),
        }
        # Single shared pool: every role is colocated on the same
        # n_gpus_per_node * nnodes GPUs.
        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
        mapping = {
            Role.ActorRollout: global_pool_id,
            Role.Critic: global_pool_id,
            Role.RefPolicy: global_pool_id,
        }
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

        # Training and validation use identically-configured reward managers.
        reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
        val_reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)

        trainer = RayPPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            reward_fn=reward_fn,
            val_reward_fn=val_reward_fn,
        )
        trainer.init_workers()
        trainer.fit()
def main(): def main():
cli_args = OmegaConf.from_cli() cli_args = OmegaConf.from_cli()
file_config = OmegaConf.load(cli_args.config)
del cli_args.config
default_config = OmegaConf.structured(PPOConfig()) default_config = OmegaConf.structured(PPOConfig())
ppo_config = OmegaConf.merge(default_config, file_config, cli_args)
if hasattr(cli_args, "config"):
config_path = cli_args.pop("config", None)
file_config = OmegaConf.load(config_path)
default_config = OmegaConf.merge(default_config, file_config)
ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config) ppo_config = OmegaConf.to_object(ppo_config)
# this is for local ray cluster
if not ray.is_initialized(): if not ray.is_initialized():
# this is for local ray cluster # for rocm
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}) if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ray.get(main_task.remote(ppo_config)) ignore_reinit_error=True,
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
else:
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
def main_task(config: PPOConfig):
config.deep_post_init() runner = Runner.remote()
print(json.dumps(config.to_dict(), indent=2)) ray.get(runner.run.remote(ppo_config))
# instantiate tokenizer
tokenizer = get_tokenizer(config.worker.actor.model.model_path)
processor = get_processor(config.worker.actor.model.model_path, use_fast=True)
# define worker classes
ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.ActorRollout: ray.remote(FSDPWorker),
Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
val_reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
import numpy as np
import torch
from ..protocol import DataProto
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Collapse each list of per-step metric values into its arithmetic mean."""
    reduced: Dict[str, Any] = {}
    for name, values in metrics.items():
        reduced[name] = np.mean(values)

    return reduced
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].size(-1)
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float()
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
    """Turn raw section timings into loggable metrics.

    Emits `timing_s/<name>` for every timed section, and additionally
    `timing_per_token_ms/<name>` for the sections with a known token count:
    `gen`/`reward` are normalized by response tokens, the model-pass
    sections by the overall token count.

    Args:
        batch: batch providing `response_mask` and `global_token_num`.
        timing_raw: mapping of section name -> elapsed seconds.

    Returns:
        A flat metrics dict.
    """
    response_tokens = torch.sum(batch.batch["response_mask"]).item()
    overall_tokens = sum(batch.meta_info["global_token_num"])

    # section name -> token count used for per-token normalization
    section_tokens: Dict[str, float] = {}
    for section in ("gen", "reward"):
        section_tokens[section] = response_tokens
    for section in ("ref", "old", "values", "adv", "update_critic", "update_actor"):
        section_tokens[section] = overall_tokens

    metrics: Dict[str, Any] = {}
    for name, seconds in timing_raw.items():
        metrics[f"timing_s/{name}"] = seconds
        if name in section_tokens:
            metrics[f"timing_per_token_ms/{name}"] = seconds * 1000 / section_tokens[name]

    return metrics
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
    """Compute step-level throughput statistics.

    Args:
        batch: batch providing `global_token_num` in its meta info.
        timing_raw: timings dict; must contain the full-step time under "step".
        n_gpus: total number of GPUs used for the step.

    Returns:
        Dict with total token count, wall-clock step time, and per-GPU
        token throughput (tokens / (seconds * gpus)).
    """
    token_total = sum(batch.meta_info["global_token_num"])
    step_seconds = timing_raw["step"]

    metrics: Dict[str, Any] = {}
    metrics["perf/total_num_tokens"] = token_total
    metrics["perf/time_per_step"] = step_seconds
    metrics["perf/throughput"] = token_total / (step_seconds * n_gpus)
    return metrics
...@@ -18,54 +18,71 @@ This trainer supports model-agonistic model initialization with huggingface ...@@ -18,54 +18,71 @@ This trainer supports model-agonistic model initialization with huggingface
import os import os
import uuid import uuid
from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum, IntEnum, auto
from pprint import pprint from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Dict, Optional, Type
import numpy as np import numpy as np
import ray
import torch import torch
from codetiming import Timer from codetiming import Timer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from ray.experimental.tqdm_ray import tqdm
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from verl import DataProto from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto from ..single_controller.base import Worker
from verl.single_controller.base import Worker from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from ..single_controller.ray.base import create_colocated_worker_cls
from verl.single_controller.ray.base import create_colocated_worker_cls from ..utils import torch_functional as VF
from verl.trainer import core_algos from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from verl.trainer.config import PPOConfig from ..utils.dataset import RLHFDataset, collate_fn
from verl.utils.rl_dataset import RLHFDataset, collate_fn from ..utils.logger import Tracker
from verl.utils.torch_functional import masked_mean from ..utils.py_functional import convert_dict_to_str
from verl.utils.tracking import Tracking from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.workers.fsdp_workers import FSDPWorker from ..workers.fsdp_workers import FSDPWorker
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
WorkerType = Type[Worker] WorkerType = Type[Worker]
class Role(Enum): class Role(IntEnum):
""" """
To create more roles dynamically, you can subclass Role and add new members To create more roles dynamically, you can subclass Role and add new members
""" """
Actor = 0 Actor = auto()
Rollout = 1 Rollout = auto()
ActorRollout = 2 ActorRollout = auto()
Critic = 3 Critic = auto()
RefPolicy = 4 RefPolicy = auto()
RewardModel = 5 RewardModel = auto()
ActorRolloutRef = 6 ActorRolloutRef = auto()
class AdvantageEstimator(str, Enum):
"""
Using an enumeration class to avoid spelling errors in adv_estimator
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REMAX = "remax"
RLOO = "rloo"
@dataclass @dataclass
class ResourcePoolManager: class ResourcePoolManager:
""" """
Define a resource pool specification. Resource pool will be initialized first. Define a resource pool specification. Resource pool will be initialized first.
Mapping
""" """
resource_pool_spec: dict[str, list[int]] resource_pool_spec: dict[str, list[int]]
...@@ -82,23 +99,41 @@ class ResourcePoolManager: ...@@ -82,23 +99,41 @@ class ResourcePoolManager:
) )
self.resource_pool_dict[resource_pool_name] = resource_pool self.resource_pool_dict[resource_pool_name] = resource_pool
self._check_resource_available()
def get_resource_pool(self, role: Role) -> RayResourcePool: def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker_cls""" """Get the resource pool of the worker."""
return self.resource_pool_dict[self.mapping[role]] return self.resource_pool_dict[self.mapping[role]]
def get_n_gpus(self) -> int:
"""Get the number of gpus in this cluster."""
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): def _check_resource_available(self):
responses = data.batch["responses"] """Check if the resource pool can be satisfied in this ray cluster."""
response_length = responses.size(1) node_available_resources = ray.state.available_resources_per_node()
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
# check total required gpus can be satisfied
total_available_gpus = sum(node_available_gpus.values())
total_required_gpus = sum(
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
)
if total_available_gpus < total_required_gpus:
raise ValueError(
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
)
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
token_level_scores = data.batch["token_level_scores"] token_level_scores = data.batch["token_level_scores"]
batch_size = data.batch.batch_size[0] batch_size = data.batch.batch_size[0]
attention_mask = data.batch["attention_mask"] response_mask = data.batch["response_mask"]
response_mask = attention_mask[:, -response_length:]
# compute kl between ref_policy and current policy # compute kl between ref_policy and current policy
if "ref_log_prob" in data.batch.keys(): if "ref_log_probs" in data.batch.keys():
kld = core_algos.kl_penalty( kld = core_algos.kl_penalty(
data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty
) # (batch_size, response_length) ) # (batch_size, response_length)
kld = kld * response_mask kld = kld * response_mask
beta = kl_ctrl.value beta = kl_ctrl.value
...@@ -108,191 +143,49 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, ...@@ -108,191 +143,49 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
token_level_rewards = token_level_scores - beta * kld token_level_rewards = token_level_scores - beta * kld
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item() current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 # According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards data.batch["token_level_rewards"] = token_level_rewards
metrics = {"critic/kl": current_kl, "critic/kl_coeff": beta} metrics = {"critic/kl": current_kl, "critic/kl_coef": beta}
return data, metrics return data, metrics
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
# prepare response group token_level_rewards = data.batch["token_level_rewards"]
# TODO: add other ways to estimate advantages response_mask = data.batch["response_mask"]
if adv_estimator == "gae": index = data.non_tensor_batch["uid"]
if adv_estimator == AdvantageEstimator.GAE:
values = data.batch["values"] values = data.batch["values"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
token_level_rewards = data.batch["token_level_rewards"]
advantages, returns = core_algos.compute_gae_advantage_return( advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=token_level_rewards, values=values, eos_mask=response_mask, gamma=gamma, lam=lam token_level_rewards=token_level_rewards, values=values, eos_mask=response_mask, gamma=gamma, lam=lam
) )
data.batch["advantages"] = advantages elif adv_estimator == AdvantageEstimator.GRPO:
data.batch["returns"] = returns
elif adv_estimator == "grpo":
token_level_rewards = data.batch["token_level_rewards"]
index = data.non_tensor_batch["uid"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
advantages, returns = core_algos.compute_grpo_outcome_advantage( advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
) )
data.batch["advantages"] = advantages elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
data.batch["returns"] = returns
elif adv_estimator == "reinforce_plus_plus":
token_level_rewards = data.batch["token_level_rewards"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma
) )
data.batch["advantages"] = advantages elif adv_estimator == AdvantageEstimator.REMAX:
data.batch["returns"] = returns
elif adv_estimator == "remax":
token_level_rewards = data.batch["token_level_rewards"]
index = data.non_tensor_batch["uid"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
reward_baselines = data.batch["reward_baselines"] reward_baselines = data.batch["reward_baselines"]
advantages, returns = core_algos.compute_remax_outcome_advantage( advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=token_level_rewards, reward_baselines=reward_baselines, eos_mask=response_mask token_level_rewards=token_level_rewards, reward_baselines=reward_baselines, eos_mask=response_mask
) )
elif adv_estimator == AdvantageEstimator.RLOO:
data.batch["advantages"] = advantages advantages, returns = core_algos.compute_rloo_outcome_advantage(
data.batch["returns"] = returns token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
)
else: else:
raise NotImplementedError raise NotImplementedError
return data
data.batch["advantages"] = advantages
def reduce_metrics(metrics: Dict[str, Any]):
    """Collapse each metric's list of values to its mean, in place, and return the dict.

    Note: mutates the argument — each value is replaced by ``np.mean(value)``.
    """
    for key in list(metrics.keys()):
        metrics[key] = np.mean(metrics[key])

    return metrics
def _compute_response_info(batch: DataProto):
    """Extract prompt/response length statistics from a rollout batch.

    The last ``responses.shape[-1]`` columns of ``attention_mask`` cover the
    response tokens; everything before them covers the prompt.

    Args:
        batch: rollout data; reads ``responses`` and ``attention_mask`` from
            ``batch.batch``.

    Returns:
        Dict with:
            - ``response_mask``: attention mask slice over response positions.
            - ``prompt_length``: float tensor of valid prompt tokens per sample.
            - ``response_length``: float tensor of valid response tokens per sample.
    """
    # Fix: the original reused the name `response_length` for both the padded
    # length (int) and the per-sample valid-token count (tensor); use a
    # distinct name for the padded length to avoid the confusing shadowing.
    max_response_length = batch.batch["responses"].shape[-1]

    prompt_mask = batch.batch["attention_mask"][:, :-max_response_length]
    response_mask = batch.batch["attention_mask"][:, -max_response_length:]

    prompt_length = prompt_mask.sum(-1).float()
    response_length = response_mask.sum(-1).float()  # (batch_size,)

    return dict(
        response_mask=response_mask,
        prompt_length=prompt_length,
        response_length=response_length,
    )
def compute_data_metrics(batch: DataProto, use_critic: bool = True):
    """Summarize a rollout batch into scalar training metrics.

    Reports mean/max/min over sequence-level scores and rewards, over
    advantages/returns restricted to response-token positions, over critic
    values (when ``use_critic``), plus prompt/response length statistics.

    Args:
        batch: rollout data; reads ``token_level_scores``, ``token_level_rewards``,
            ``advantages``, ``returns``, ``responses``, ``attention_mask`` and,
            when ``use_critic``, ``values`` from ``batch.batch``.
        use_critic: whether to include critic value metrics and the
            value-function explained variance.

    Returns:
        Dict mapping metric names (e.g. ``critic/score/mean``) to Python floats.
    """
    # TODO: add response length
    # Per-sequence totals of token-level scores (raw) and rewards (post-KL-penalty).
    sequence_score = batch.batch["token_level_scores"].sum(-1)
    sequence_reward = batch.batch["token_level_rewards"].sum(-1)

    advantages = batch.batch["advantages"]
    returns = batch.batch["returns"]

    # The padded response length; response tokens occupy the last
    # `max_response_length` columns of `attention_mask`.
    max_response_length = batch.batch["responses"].shape[-1]

    prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
    response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()

    max_prompt_length = prompt_mask.size(-1)

    response_info = _compute_response_info(batch)
    prompt_length = response_info["prompt_length"]
    response_length = response_info["response_length"]

    # Keep only values at real (non-padding) response positions.
    valid_adv = torch.masked_select(advantages, response_mask)
    valid_returns = torch.masked_select(returns, response_mask)

    if use_critic:
        values = batch.batch["values"]
        valid_values = torch.masked_select(values, response_mask)
        # Variances used for explained variance: 1 - Var(returns - values) / Var(returns).
        return_diff_var = torch.var(valid_returns - valid_values)
        return_var = torch.var(valid_returns)

    metrics = {
        # score
        "critic/score/mean": torch.mean(sequence_score).detach().item(),
        "critic/score/max": torch.max(sequence_score).detach().item(),
        "critic/score/min": torch.min(sequence_score).detach().item(),
        # reward
        "critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
        "critic/rewards/max": torch.max(sequence_reward).detach().item(),
        "critic/rewards/min": torch.min(sequence_reward).detach().item(),
        # adv
        "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
        "critic/advantages/max": torch.max(valid_adv).detach().item(),
        "critic/advantages/min": torch.min(valid_adv).detach().item(),
        # returns
        "critic/returns/mean": torch.mean(valid_returns).detach().item(),
        "critic/returns/max": torch.max(valid_returns).detach().item(),
        "critic/returns/min": torch.min(valid_returns).detach().item(),
        **(
            {
                # values
                "critic/values/mean": torch.mean(valid_values).detach().item(),
                "critic/values/max": torch.max(valid_values).detach().item(),
                "critic/values/min": torch.min(valid_values).detach().item(),
                # vf explained var (epsilon guards division by zero variance)
                "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
            }
            if use_critic
            else {}
        ),
        # response length
        "response_length/mean": torch.mean(response_length).detach().item(),
        "response_length/max": torch.max(response_length).detach().item(),
        "response_length/min": torch.min(response_length).detach().item(),
        # fraction of responses that hit the length cap (likely truncated)
        "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
        .detach()
        .item(),
        # prompt length
        "prompt_length/mean": torch.mean(prompt_length).detach().item(),
        "prompt_length/max": torch.max(prompt_length).detach().item(),
        "prompt_length/min": torch.min(prompt_length).detach().item(),
        "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
    }
    return metrics
def compute_timing_metrics(batch, timing_raw):
    """Build timing metrics for each PPO section.

    Emits ``timing_s/<name>`` (raw seconds) for every timed section, and
    ``timing_per_token_ms/<name>`` for sections with a known token count.
    """
    info = _compute_response_info(batch)
    prompt_tokens = torch.sum(info["prompt_length"]).item()
    response_tokens = torch.sum(info["response_length"]).item()
    overall_tokens = prompt_tokens + response_tokens

    # Generation only produces response tokens; every other section
    # processes the full (prompt + response) sequence.
    section_tokens = {name: overall_tokens for name in ("ref", "values", "adv", "update_critic", "update_actor")}
    section_tokens["gen"] = response_tokens

    metrics = {f"timing_s/{name}": value for name, value in timing_raw.items()}
    for name in section_tokens.keys() & timing_raw.keys():
        metrics[f"timing_per_token_ms/{name}"] = timing_raw[name] * 1000 / section_tokens[name]

    return metrics
@contextmanager @contextmanager
...@@ -308,8 +201,6 @@ class RayPPOTrainer: ...@@ -308,8 +201,6 @@ class RayPPOTrainer:
Note that this trainer runs on the driver process on a single CPU/GPU node. Note that this trainer runs on the driver process on a single CPU/GPU node.
""" """
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__( def __init__(
self, self,
config: PPOConfig, config: PPOConfig,
...@@ -318,8 +209,8 @@ class RayPPOTrainer: ...@@ -318,8 +209,8 @@ class RayPPOTrainer:
role_worker_mapping: dict[Role, WorkerType], role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager, resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None, reward_fn: Callable = None,
val_reward_fn=None, val_reward_fn: Callable = None,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor self.processor = processor
...@@ -328,42 +219,51 @@ class RayPPOTrainer: ...@@ -328,42 +219,51 @@ class RayPPOTrainer:
self.val_reward_fn = val_reward_fn self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.worker.hybrid_engine self.hybrid_engine = config.worker.hybrid_engine
assert self.hybrid_engine, "Currently, only support hybrid engine"
if self.hybrid_engine: if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()}" assert Role.ActorRollout in role_worker_mapping, (
f"ActorRollout should be included in {role_worker_mapping.keys()}."
)
else:
raise NotImplementedError
self.role_worker_mapping = role_worker_mapping self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_reward_model = Role.RewardModel in role_worker_mapping self.use_reward_model = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls self.ray_worker_group_cls = ray_worker_group_cls
# define KL control # define KL control
if self.use_reference_policy: if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
self.use_reference_policy = True
self.kl_ctrl = core_algos.get_kl_controller(config.algorithm) self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
else: else:
self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0) self.use_reference_policy = False
self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
print("KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics.")
if self.config.algorithm.adv_estimator == "gae": if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True self.use_critic = True
elif self.config.algorithm.adv_estimator == "grpo":
self.use_critic = False
elif self.config.algorithm.adv_estimator == "reinforce_plus_plus":
self.use_critic = False
elif self.config.algorithm.adv_estimator == "remax":
self.use_critic = False
else: else:
raise NotImplementedError self.use_critic = False
if config.algorithm.adv_estimator not in list(AdvantageEstimator):
raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.")
if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")
if self.use_critic and config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")
self._create_dataloader() self._create_dataloader()
def _create_dataloader(self): def _create_dataloader(self) -> None:
self.train_dataset = RLHFDataset( self.train_dataset = RLHFDataset(
data_path=self.config.data.train_files, data_path=self.config.data.train_files,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
processor=self.processor, processor=self.processor,
prompt_key=self.config.data.prompt_key, prompt_key=self.config.data.prompt_key,
answer_key=self.config.data.answer_key,
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length, max_prompt_length=self.config.data.max_prompt_length,
truncation="right", truncation="right",
system_prompt=self.config.data.system_prompt, system_prompt=self.config.data.system_prompt,
...@@ -378,13 +278,14 @@ class RayPPOTrainer: ...@@ -378,13 +278,14 @@ class RayPPOTrainer:
else: else:
sampler = SequentialSampler(data_source=self.train_dataset) sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = DataLoader( self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset, dataset=self.train_dataset,
batch_size=self.config.data.rollout_batch_size, batch_size=self.config.data.rollout_batch_size,
sampler=sampler,
num_workers=8, num_workers=8,
drop_last=True,
collate_fn=collate_fn, collate_fn=collate_fn,
sampler=sampler, pin_memory=False,
drop_last=True,
) )
self.val_dataset = RLHFDataset( self.val_dataset = RLHFDataset(
...@@ -392,24 +293,28 @@ class RayPPOTrainer: ...@@ -392,24 +293,28 @@ class RayPPOTrainer:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
processor=self.processor, processor=self.processor,
prompt_key=self.config.data.prompt_key, prompt_key=self.config.data.prompt_key,
answer_key=self.config.data.answer_key,
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length, max_prompt_length=self.config.data.max_prompt_length,
truncation="right", truncation="right",
system_prompt=self.config.data.system_prompt, system_prompt=self.config.data.system_prompt,
min_pixels=self.config.data.min_pixels, min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels, max_pixels=self.config.data.max_pixels,
) )
self.val_dataloader = DataLoader( self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset, dataset=self.val_dataset,
batch_size=len(self.val_dataset), batch_size=len(self.val_dataset)
num_workers=8, if self.config.data.val_batch_size == -1
else self.config.data.val_batch_size,
shuffle=False, shuffle=False,
drop_last=False, num_workers=8,
collate_fn=collate_fn, collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
) )
assert len(self.train_dataloader) >= 1 assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1 assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}") print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}") print(f"Size of val dataloader: {len(self.val_dataloader)}")
...@@ -423,20 +328,11 @@ class RayPPOTrainer: ...@@ -423,20 +328,11 @@ class RayPPOTrainer:
self.config.worker.critic.optim.training_steps = training_steps self.config.worker.critic.optim.training_steps = training_steps
print(f"Total training steps: {self.training_steps}") print(f"Total training steps: {self.training_steps}")
def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores): def _maybe_log_val_generations(self, inputs: List[str], outputs: List[str], scores: List[float]) -> None:
"""Log a table of validation samples to wandb""" """Log a table of validation samples"""
if self.config.trainer.val_generations_to_log <= 0:
generations_to_log = self.config.trainer.val_generations_to_log_to_wandb
if generations_to_log == 0:
return return
if generations_to_log > 0 and "wandb" not in self.config.trainer.logger:
print("WARNING: `val_generations_to_log_to_wandb` is set, but no wandb logger is found.")
return
import wandb
# Create tuples of (input, output, score) and sort by input text # Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores)) samples = list(zip(inputs, outputs, scores))
samples.sort(key=lambda x: x[0]) # Sort by input text samples.sort(key=lambda x: x[0]) # Sort by input text
...@@ -445,43 +341,14 @@ class RayPPOTrainer: ...@@ -445,43 +341,14 @@ class RayPPOTrainer:
rng = np.random.RandomState(42) rng = np.random.RandomState(42)
rng.shuffle(samples) rng.shuffle(samples)
# Take first N samples after shuffling samples = samples[: self.config.trainer.val_generations_to_log]
samples = samples[:generations_to_log] self.logger.log_generation(samples, self.global_step)
# Create column names for all samples def _validate(self) -> Dict[str, Any]:
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)
if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)
# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)
# Add new row with all data
row_data = []
row_data.append(self.global_steps)
for sample in samples:
row_data.extend(sample)
new_table.add_data(*row_data)
# Update reference and log
wandb.log({"val/generations": new_table}, step=self.global_steps)
self.validation_table = new_table
def _validate(self):
reward_tensor_lst = [] reward_tensor_lst = []
data_source_lst = []
# Lists to collect samples for the table # Lists to collect samples for the table
sample_inputs = [] sample_inputs, sample_outputs, sample_scores = [], [], []
sample_outputs = [] reward_metrics_lst = defaultdict(list)
sample_scores = []
for test_data in self.val_dataloader: for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data) test_batch = DataProto.from_single_dict(test_data)
# Store original inputs # Store original inputs
...@@ -489,10 +356,10 @@ class RayPPOTrainer: ...@@ -489,10 +356,10 @@ class RayPPOTrainer:
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts) sample_inputs.extend(input_texts)
if "pixel_values" in test_batch.non_tensor_batch.keys(): if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop( test_gen_batch = test_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"], batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["pixel_values", "image_grid_thw", "raw_prompt_ids", "images"], non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
) )
else: else:
test_gen_batch = test_batch.pop( test_gen_batch = test_batch.pop(
...@@ -500,15 +367,10 @@ class RayPPOTrainer: ...@@ -500,15 +367,10 @@ class RayPPOTrainer:
non_tensor_batch_keys=["raw_prompt_ids"], non_tensor_batch_keys=["raw_prompt_ids"],
) )
test_gen_batch.meta_info = {"do_sample": False} test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
# pad to be divisible by dp_size test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor( test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
test_gen_batch, self.actor_rollout_wg.world_size
)
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print("validation generation end") print("validation generation end")
# Store generated outputs # Store generated outputs
...@@ -519,40 +381,24 @@ class RayPPOTrainer: ...@@ -519,40 +381,24 @@ class RayPPOTrainer:
test_batch = test_batch.union(test_output_gen_batch) test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function # evaluate using reward_function
reward_tensor = self.val_reward_fn(test_batch) reward_tensor, reward_metrics = self.val_reward_fn(test_batch)
# Store scores # Store scores
scores = reward_tensor.sum(-1).cpu().tolist() scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores) sample_scores.extend(scores)
reward_tensor_lst.append(reward_tensor) reward_tensor_lst.append(reward_tensor)
data_source_lst.append( for key, value in reward_metrics.items():
test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]) reward_metrics_lst[key].extend(value)
)
self._maybe_log_val_generations_to_wandb(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
return {"val/reward_score": reward_score, **val_reward_metrics}
reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) def init_workers(self) -> None:
data_sources = np.concatenate(data_source_lst, axis=0)
# evaluate test_score based on data source
data_source_reward = {}
for i in range(reward_tensor.shape[0]):
data_source = data_sources[i]
if data_source not in data_source_reward:
data_source_reward[data_source] = []
data_source_reward[data_source].append(reward_tensor[i].item())
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
return metric_dict
def init_workers(self):
"""Init resource pool and worker group""" """Init resource pool and worker group"""
self.resource_pool_manager.create_resource_pool() self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout # create actor and rollout
...@@ -594,7 +440,7 @@ class RayPPOTrainer: ...@@ -594,7 +440,7 @@ class RayPPOTrainer:
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size, # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {} all_wg: Dict[str, FSDPWorker] = {}
self.wg_dicts = [] self.wg_dicts = []
for resource_pool, class_dict in self.resource_pool_to_cls.items(): for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
...@@ -605,62 +451,80 @@ class RayPPOTrainer: ...@@ -605,62 +451,80 @@ class RayPPOTrainer:
self.wg_dicts.append(wg_dict) self.wg_dicts.append(wg_dict)
if self.use_critic: if self.use_critic:
self.critic_wg: FSDPWorker = all_wg["critic"] self.critic_wg = all_wg["critic"]
self.critic_wg.init_model() self.critic_wg.init_model()
if self.use_reference_policy: if self.use_reference_policy:
self.ref_policy_wg: FSDPWorker = all_wg["ref"] self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg.init_model() self.ref_policy_wg.init_model()
if self.use_reward_model: if self.use_reward_model:
self.rm_wg: FSDPWorker = all_wg["rm"] self.rm_wg = all_wg["rm"]
self.rm_wg.init_model() self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg: FSDPWorker = all_wg["actor_rollout"] self.actor_rollout_wg = all_wg["actor_rollout"]
self.actor_rollout_wg.init_model() self.actor_rollout_wg.init_model()
def _save_checkpoint(self): def _save_checkpoint(self) -> None:
# path: {save_checkpoint_path}/global_step_{global_steps}/actor # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
local_global_step_folder = os.path.join( remove_obsolete_ckpt(
self.config.trainer.save_checkpoint_path, f"global_step_{self.global_steps}" self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
)
actor_local_path = os.path.join(local_global_step_folder, "actor")
self.actor_rollout_wg.save_checkpoint(
actor_local_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt,
) )
folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
actor_path = os.path.join(folder_path, "actor")
self.actor_rollout_wg.save_checkpoint(actor_path)
if self.use_critic: if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, "critic") critic_path = os.path.join(folder_path, "critic")
self.critic_wg.save_checkpoint( self.critic_wg.save_checkpoint(critic_path)
critic_local_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt,
)
local_latest_checkpointed_iteration = os.path.join( dataloader_path = os.path.join(folder_path, "dataloader.pt")
self.config.trainer.save_checkpoint_path, "latest_checkpointed_iteration.txt" dataloader_state_dict = self.train_dataloader.state_dict()
) torch.save(dataloader_state_dict, dataloader_path)
with open(local_latest_checkpointed_iteration, "w") as f:
f.write(str(self.global_steps))
def _load_checkpoint(self): last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
with open(last_global_step_path, "w") as f:
f.write(str(self.global_step))
def _load_checkpoint(self) -> None:
if self.config.trainer.load_checkpoint_path is None: if self.config.trainer.load_checkpoint_path is None:
return return
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}") if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
raise ValueError("`load_checkpoint_path` should end with `global_step_*`.")
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.")
self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor") actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic") self.actor_rollout_wg.load_checkpoint(actor_path)
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
if self.use_critic: if self.use_critic:
self.critic_wg.load_checkpoint( critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load self.critic_wg.load_checkpoint(critic_path)
)
dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt")
if os.path.exists(dataloader_path):
dataloader_state_dict = torch.load(dataloader_path, weights_only=False)
self.train_dataloader.load_state_dict(dataloader_state_dict)
else:
print(f"No dataloader state found at {dataloader_path}, will start from scratch.")
def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
    """Reorder the data on single controller such that each dp rank gets similar total tokens.

    Mutates ``batch`` in place (rows are permuted) and adds sequence-length
    balance statistics to ``metrics`` under the ``logging_prefix`` keys.
    """
    attention_mask = batch.batch["attention_mask"]
    batch_size = attention_mask.shape[0]
    # Total valid tokens per sample, used as the per-item "weight" to balance.
    global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
    world_size = self.actor_rollout_wg.world_size
    # Partition sample indices into `world_size` equally sized groups with
    # similar total token counts (project helper — see verl ray trainer utils).
    global_partition_lst = get_seqlen_balanced_partitions(
        global_seqlen_lst, k_partitions=world_size, equal_size=True
    )
    # reorder based on index. The data will be automatically equally partitioned by dispatch function
    global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
    batch.reorder(global_idx)
    global_balance_stats = log_seqlen_unbalance(
        seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
    )
    metrics.update(global_balance_stats)
def fit(self): def fit(self):
""" """
...@@ -668,13 +532,9 @@ class RayPPOTrainer: ...@@ -668,13 +532,9 @@ class RayPPOTrainer:
The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process. The light-weight advantage computation is done on the driver process.
""" """
logger = Tracking( self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
project_name=self.config.trainer.project_name, self.global_step = 0
experiment_name=self.config.trainer.experiment_name, val_metrics: Optional[Dict[str, Any]] = None
default_backend=self.config.trainer.logger,
config=self.config.to_dict(),
)
self.global_steps = 0
# load checkpoint before doing anything # load checkpoint before doing anything
self._load_checkpoint() self._load_checkpoint()
...@@ -683,27 +543,24 @@ class RayPPOTrainer: ...@@ -683,27 +543,24 @@ class RayPPOTrainer:
# currently, we only support validation using the reward_function. # currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.val_before_train: if self.val_reward_fn is not None and self.config.trainer.val_before_train:
val_metrics = self._validate() val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}") self.logger.log(data=val_metrics, step=self.global_step)
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.val_only: if self.config.trainer.val_only:
return return
for _ in range(self.config.trainer.total_episodes): for _ in tqdm(range(self.config.trainer.total_episodes), desc="Episode", position=0):
for batch_dict in self.train_dataloader: for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
self.global_steps += 1 self.global_step += 1
if self.global_steps >= self.training_steps: if self.global_step > self.training_steps:
break break
metrics = {} metrics, timing_raw = {}, {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict) batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation # pop those keys for generation
if "pixel_values" in batch.non_tensor_batch.keys(): if "multi_modal_inputs" in batch.non_tensor_batch.keys():
gen_batch = batch.pop( gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"], batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["pixel_values", "image_grid_thw", "raw_prompt_ids", "images"], non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
) )
else: else:
gen_batch = batch.pop( gen_batch = batch.pop(
...@@ -719,17 +576,15 @@ class RayPPOTrainer: ...@@ -719,17 +576,15 @@ class RayPPOTrainer:
if self.config.algorithm.adv_estimator == "remax": if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw): with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False gen_baseline_batch.meta_info["temperature"] = 0.0
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output) batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch) reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch["uid"] = np.array( batch.non_tensor_batch["uid"] = np.array(
...@@ -739,24 +594,37 @@ class RayPPOTrainer: ...@@ -739,24 +594,37 @@ class RayPPOTrainer:
batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True) batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
batch = batch.union(gen_batch_output) batch = batch.union(gen_batch_output)
# compute reward
with _timer("reward", timing_raw):
if self.use_reward_model:
raise NotImplementedError("Reward model is not supported yet.")
# we combine with rule-based rm
reward_tensor, reward_metrics = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
}
metrics.update(reward_metrics)
# balance the number of valid tokens on each dp rank. # balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch. # Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo # Please take care when you implement group based adv computation such as GRPO and rloo
# self._balance_batch(batch, metrics=metrics) # TODO: re-enable balance batch self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens # compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs # recompute old_log_probs
with _timer("old_log_prob", timing_raw): with _timer("old", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_prob) batch = batch.union(old_log_probs)
# compute ref_log_probs
if self.use_reference_policy: if self.use_reference_policy:
# compute reference log_prob
with _timer("ref", timing_raw): with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_prob) batch = batch.union(ref_log_probs)
# compute values # compute values
if self.use_critic: if self.use_critic:
...@@ -765,18 +633,8 @@ class RayPPOTrainer: ...@@ -765,18 +633,8 @@ class RayPPOTrainer:
batch = batch.union(values) batch = batch.union(values)
with _timer("adv", timing_raw): with _timer("adv", timing_raw):
# compute scores. Support both model and function-based. # apply kl penalty if available
# We first compute the scores using reward model. Then, we call reward_fn to combine if not self.config.algorithm.use_kl_loss and self.use_reference_policy: # apply kl penalty to reward
# the results from reward model and rule-based results.
if self.use_reward_model:
raise NotImplementedError
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
# compute rewards. apply_kl_penalty if available
if not self.config.worker.actor.use_kl_loss: # not grpo
batch, kl_metrics = apply_kl_penalty( batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
) )
...@@ -790,7 +648,6 @@ class RayPPOTrainer: ...@@ -790,7 +648,6 @@ class RayPPOTrainer:
adv_estimator=self.config.algorithm.adv_estimator, adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma, gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam, lam=self.config.algorithm.lam,
num_repeat=self.config.worker.rollout.n,
) )
# update critic # update critic
...@@ -798,43 +655,51 @@ class RayPPOTrainer: ...@@ -798,43 +655,51 @@ class RayPPOTrainer:
with _timer("update_critic", timing_raw): with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch) critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
metrics.update(critic_output_metrics) metrics.update(critic_metrics)
# implement critic warmup # update actor
if self.config.trainer.critic_warmup <= self.global_steps: if self.config.trainer.critic_warmup <= self.global_step:
# update actor
with _timer("update_actor", timing_raw): with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch) actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
metrics.update(actor_output_metrics) metrics.update(actor_metrics)
# validate # validate
if ( if (
self.val_reward_fn is not None self.val_reward_fn is not None
and self.config.trainer.test_freq > 0 and self.config.trainer.val_freq > 0
and self.global_steps % self.config.trainer.test_freq == 0 and self.global_step % self.config.trainer.val_freq == 0
): ):
with _timer("testing", timing_raw): with _timer("validation", timing_raw):
val_metrics: dict = self._validate() val_metrics = self._validate()
metrics.update(val_metrics) metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
with _timer("save_checkpoint", timing_raw): with _timer("save_checkpoint", timing_raw):
self._save_checkpoint() self._save_checkpoint()
# collect metrics # collect metrics
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# TODO: make a canonical logger that supports various backend self.logger.log(data=metrics, step=self.global_step)
logger.log(data=metrics, step=self.global_steps)
# perform validation after training # perform validation after training
if self.val_reward_fn is not None: if self.val_reward_fn is not None:
val_metrics = self._validate() if (
pprint(f"Final validation metrics: {val_metrics}") val_metrics is None
logger.log(data=val_metrics, step=self.global_steps) or self.config.trainer.val_freq <= 0
or self.global_step % self.config.trainer.val_freq != 0
self._save_checkpoint() ):
val_metrics = self._validate()
self.logger.log(data=val_metrics, step=self.global_step)
print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")
if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
self._save_checkpoint()
...@@ -11,8 +11,3 @@ ...@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .tokenizer import get_processor, get_tokenizer
__all__ = ["get_processor", "get_tokenizer"]
...@@ -11,3 +11,8 @@ ...@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
__all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
...@@ -11,20 +11,27 @@ ...@@ -11,20 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import random import random
import re
import shutil import shutil
import tempfile import tempfile
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.distributed import torch.distributed as dist
from filelock import FileLock from filelock import FileLock
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
class BaseCheckpointManager: CHECKPOINT_TRACKER = "latest_global_step.txt"
class BaseCheckpointManager(ABC):
""" """
A checkpoint manager that saves and loads A checkpoint manager that saves and loads
- model - model
...@@ -44,42 +51,27 @@ class BaseCheckpointManager: ...@@ -44,42 +51,27 @@ class BaseCheckpointManager:
model: FSDP, model: FSDP,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler, lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer, processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
processor: ProcessorMixin
): ):
self.previous_global_step = None
self.previous_save_local_path = None
self.model = model self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer self.processing_class = processing_class
self.processor = processor
assert isinstance(self.model, FSDP) assert isinstance(self.model, FSDP)
self.rank = torch.distributed.get_rank() self.rank = dist.get_rank()
self.world_size = torch.distributed.get_world_size() self.world_size = dist.get_world_size()
@abstractmethod
def load_checkpoint(self, *args, **kwargs): def load_checkpoint(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@abstractmethod
def save_checkpoint(self, *args, **kwargs): def save_checkpoint(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
def remove_previous_save_local_path(self):
if not self.previous_save_local_path:
return
abs_path = os.path.abspath(self.previous_save_local_path)
print(f"Checkpoint manager remove previous save local path: {abs_path}")
if not os.path.exists(abs_path):
return
# remove previous local_path
shutil.rmtree(abs_path, ignore_errors=True)
@staticmethod @staticmethod
def local_mkdir(path): def local_mkdir(path: str) -> str:
if not os.path.isabs(path): if not os.path.isabs(path):
working_dir = os.getcwd() working_dir = os.getcwd()
path = os.path.join(working_dir, path) path = os.path.join(working_dir, path)
...@@ -89,18 +81,16 @@ class BaseCheckpointManager: ...@@ -89,18 +81,16 @@ class BaseCheckpointManager:
lock_path = os.path.join(tempfile.gettempdir(), lock_filename) lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
try: try:
with FileLock(lock_path, timeout=60): # Add timeout with FileLock(lock_path, timeout=60):
# make a new dir
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
except Exception as e: except Exception as e:
print(f"Warning: Failed to acquire lock for {path}: {e}") print(f"Warning: Failed to acquire lock for {path}: {e}")
# Even if the lock is not acquired, try to create the directory os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory
os.makedirs(path, exist_ok=True)
return path return path
@staticmethod @staticmethod
def get_rng_state(): def get_rng_state() -> Dict[str, Any]:
rng_state = { rng_state = {
"cpu": torch.get_rng_state(), "cpu": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state(), "cuda": torch.cuda.get_rng_state(),
...@@ -110,14 +100,14 @@ class BaseCheckpointManager: ...@@ -110,14 +100,14 @@ class BaseCheckpointManager:
return rng_state return rng_state
@staticmethod @staticmethod
def load_rng_state(rng_state): def load_rng_state(rng_state: Dict[str, Any]):
torch.set_rng_state(rng_state["cpu"]) torch.set_rng_state(rng_state["cpu"])
torch.cuda.set_rng_state(rng_state["cuda"]) torch.cuda.set_rng_state(rng_state["cuda"])
np.random.set_state(rng_state["numpy"]) np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"]) random.setstate(rng_state["random"])
def find_latest_ckpt_path(path, directory_format="global_step_{}"): def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
if path is None: if path is None:
return None return None
...@@ -128,6 +118,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"): ...@@ -128,6 +118,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
with open(tracker_file, "rb") as f: with open(tracker_file, "rb") as f:
iteration = int(f.read().decode()) iteration = int(f.read().decode())
ckpt_path = os.path.join(path, directory_format.format(iteration)) ckpt_path = os.path.join(path, directory_format.format(iteration))
if not os.path.exists(ckpt_path): if not os.path.exists(ckpt_path):
print("Checkpoint does not exist: %s", ckpt_path) print("Checkpoint does not exist: %s", ckpt_path)
...@@ -137,8 +128,33 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"): ...@@ -137,8 +128,33 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
return ckpt_path return ckpt_path
def get_checkpoint_tracker_filename(root_path: str) -> str:
    """Return the path of the tracker file under ``root_path``.

    The tracker file records the latest checkpointed step during training
    so that a restarted run can resume from it.
    """
    return os.path.join(root_path, CHECKPOINT_TRACKER)
def remove_obsolete_ckpt(
    path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"
) -> None:
    """Remove obsolete checkpoint folders so at most ``save_limit`` checkpoints are kept.

    Keeps the ``save_limit - 1`` most recent checkpoints that are strictly older than
    ``global_step`` (the checkpoint at ``global_step`` itself is saved separately by
    the caller) and deletes the rest.

    Args:
        path: directory containing the checkpoint folders.
        global_step: current training step; only strictly older folders are candidates.
        save_limit: maximum number of checkpoints to retain; non-positive disables pruning.
        directory_format: folder-name template with a single ``{}`` for the step number.
    """
    if save_limit <= 0 or not os.path.exists(path):
        return

    # Turn the template into a regex capturing the step number, e.g. r"global_step_(\d+)".
    pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
    ckpt_folders = []
    for folder in os.listdir(path):
        # fullmatch (not match) so e.g. "global_step_3_backup" is never
        # mistaken for checkpoint step 3 and pruned by accident.
        if match := re.fullmatch(pattern, folder):
            step = int(match.group(1))
            if step < global_step:
                ckpt_folders.append((step, folder))

    ckpt_folders.sort(reverse=True)  # newest first
    # Keep the newest ``save_limit - 1`` old checkpoints; remove everything older.
    for _, folder in ckpt_folders[save_limit - 1 :]:
        folder_path = os.path.join(path, folder)
        shutil.rmtree(folder_path, ignore_errors=True)
        print(f"Removed obsolete checkpoint: {folder_path}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment