Commit c132cbcb authored by chenych

0402 update

parent f92481f0
......@@ -12,11 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_requires():
def get_version() -> str:
with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"__version__\W*=\W*\"([^\"]+)\""
(version,) = re.findall(pattern, file_content)
return version
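# Quick sanity check (hypothetical input string, not part of the original change) for the
# version regex above: it captures the quoted value assigned to __version__ in verl/__init__.py.
def _demo_version_pattern() -> None:
    pattern = r"__version__\W*=\W*\"([^\"]+)\""
    (version,) = re.findall(pattern, '__version__ = "0.2.0.dev"')
    assert version == "0.2.0.dev"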
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
......@@ -31,18 +41,19 @@ extra_require = {
def main():
setup(
name="verl",
version="0.2.0.dev0",
package_dir={"": "."},
packages=find_packages(where="."),
url="https://github.com/volcengine/verl",
license="Apache 2.0",
version=get_version(),
description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="verl",
author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
description="",
license="Apache 2.0 License",
url="https://github.com/volcengine/verl",
package_dir={"": "."},
packages=find_packages(where="."),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
)
......
......@@ -12,8 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .protocol import DataProto
__all__ = ["DataProto"]
__version__ = "0.2.0.dev"
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .transformers.flash_attention_utils import flash_attention_forward
from .transformers.qwen2_vl import qwen2_vl_attn_forward
def apply_ulysses_patch(model_type: str) -> None:
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
elif model_type in ("qwen2_vl", "qwen2_5_vl"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
else:
raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
from ...utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_world_size,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
_flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def prepare_fa2_from_position_ids(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seqlens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
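# Illustrative sketch (hypothetical inputs, not part of the original change): how the
# cu_seqlens above fall out of flattened position_ids. Positions restart at 0 for every
# packed sequence, so the indices where position_ids == 0 mark sequence starts; appending
# the total length yields the cumulative lengths expected by flash_attn_varlen_func.
def _demo_cu_seqlens_from_position_ids() -> None:
    position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3], dtype=torch.int32)
    indices = torch.arange(position_ids.size(0), dtype=torch.int32)
    cu_seqlens = torch.cat(
        (indices[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
    )
    assert cu_seqlens.tolist() == [0, 3, 5, 9]
    assert cu_seqlens.diff().max().item() == 4  # max_length shared by q and k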
def _custom_flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
query_length: int,
is_causal: bool = True,
position_ids: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
deterministic: Optional[bool] = None,
**kwargs,
):
"""
Patches flash attention forward to handle the 3D position ids (3, batch_size, seq_length) used by mrope.
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if _flash_supports_deterministic:
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
if kwargs.get("softcap") is not None:
flash_kwargs["softcap"] = kwargs.pop("softcap")
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype=torch.bfloat16
)
sp_size = get_ulysses_sequence_parallel_world_size()
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids[0]
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=kwargs.pop("dropout", 0.0),
softmax_scale=kwargs.pop("softmax_scale", None),
causal=causal,
**flash_kwargs,
)
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_length,
is_causal=is_causal,
sliding_window=sliding_window,
use_top_left_mask=use_top_left_mask,
deterministic=deterministic,
**kwargs,
) # do not pass position_ids to old flash_attention_forward
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
return attn_output
def flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
# This is before the transpose
q_len = query.shape[2]
# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
kwargs.pop("is_causal", None)
attn_output = _custom_flash_attention_forward(
query,
key,
value,
attention_mask,
query_length=q_len,
is_causal=True,
dropout=dropout,
softmax_scale=scaling,
sliding_window=sliding_window,
softcap=softcap,
use_top_left_mask=_flash_use_top_left_mask,
**kwargs,
)
return attn_output, None
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from .flash_attention_utils import flash_attention_forward
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
apply_multimodal_rotary_pos_emb,
repeat_kv,
)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
except ImportError:
pass
def get_rope_index(
processor: "Qwen2VLProcessor",
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
video_grid_thw: Optional[torch.Tensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Gets the position ids for Qwen2-VL. They should be generated before sharding the sequence.
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
"""
spatial_merge_size = processor.image_processor.merge_size
tokens_per_second = 2
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
image_index, video_index = 0, 0
input_ids = input_ids[attention_mask == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
else:
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
return position_ids
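# Minimal sketch (hypothetical inputs, not part of the original change) of the text-only
# branch above: without image or video grids, the three m-rope channels (t, h, w) share the
# same 1D positions derived from the attention mask, with padded slots filled with 1.
def _demo_text_only_rope_index() -> None:
    attention_mask = torch.tensor([1, 1, 1, 0, 0])
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    position_ids = position_ids.unsqueeze(0).expand(3, -1)  # (3, seqlen)
    assert position_ids.tolist() == [[0, 1, 2, 1, 1]] * 3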
def qwen2_vl_attn_forward(
self: "Qwen2VLAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, None, None]:
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
attn_output, _ = flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=dropout_rate,
sliding_window=sliding_window,
position_ids=position_ids, # important: pass position ids
) # (batch_size, seq_length, num_head / sp_size, head_size)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, None, None
......@@ -26,13 +26,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import ray
import torch
import torch.distributed as dist
from numpy.typing import NDArray
from tensordict import TensorDict
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader
from verl.utils.py_functional import union_two_dict
from .utils.py_functional import union_two_dict
try:
......@@ -89,21 +88,22 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
)
for key, value in tensor_dict2.items():
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], value):
for key in tensor_dict2.keys():
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = value
tensor_dict1[key] = tensor_dict2[key]
return tensor_dict1
def union_numpy_dict(
tensor_dict1: Dict[str, Union[List, NDArray]], tensor_dict2: Dict[str, Union[List, NDArray]]
) -> Dict[str, Union[List, NDArray]]:
for key, value in tensor_dict2.items():
if key in tensor_dict1 and isinstance(value, np.ndarray) and not np.all(tensor_dict1[key] == value):
raise ValueError(f"Key already exists: {key}.")
def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
for key in tensor_dict2.keys():
if key in tensor_dict1:
assert isinstance(tensor_dict2[key], np.ndarray)
assert isinstance(tensor_dict1[key], np.ndarray)
if not np.all(tensor_dict1[key] == tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = tensor_dict2[key]
......@@ -151,6 +151,7 @@ def collate_fn(data_items: list["DataProtoItem"]):
batch = torch.stack(batch).contiguous()
non_tensor_batch = batch_collate(non_tensor_batch)
non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()}
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
......@@ -189,7 +190,8 @@ class DataProto:
def __getitem__(self, item):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self):
buffer = io.BytesIO()
......@@ -229,8 +231,7 @@ class DataProto:
size_of_numpy_array = 0
for value in self.non_tensor_batch.values():
if isinstance(value, np.ndarray):
size_of_numpy_array += value.nbytes
size_of_numpy_array += value.nbytes
size_of_numpy_array /= 1024**3
size_of_tensordict /= 1024**3
......@@ -254,13 +255,13 @@ class DataProto:
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
@classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, list, NDArray]], meta_info=None):
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, NDArray]], meta_info=None):
tensors = {}
non_tensors = {}
for key, value in data.items():
if isinstance(value, torch.Tensor):
tensors[key] = value
elif isinstance(value, (list, np.ndarray)):
elif isinstance(value, np.ndarray):
non_tensors[key] = value
else:
raise ValueError(f"Unsupported type in data {type(value)}")
......@@ -472,8 +473,6 @@ class DataProto:
assert len(self) % chunks == 0, (
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
)
chunk_size = len(self) // chunks
if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
else:
......@@ -481,12 +480,8 @@ class DataProto:
non_tensor_batch_lst = [{} for _ in range(chunks)]
for key, value in self.non_tensor_batch.items():
assert isinstance(value, (list, np.ndarray))
if isinstance(value, np.ndarray):
non_tensor_lst = np.array_split(value, chunks)
else:
non_tensor_lst = [value[i : i + chunk_size] for i in range(0, len(self), chunk_size)]
assert isinstance(value, np.ndarray)
non_tensor_lst = np.array_split(value, chunks)
assert len(non_tensor_lst) == chunks
for i in range(chunks):
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
......@@ -524,12 +519,7 @@ class DataProto:
non_tensor_batch = batch_collate([d.non_tensor_batch for d in data])
for key, value in non_tensor_batch.items():
if isinstance(value[0], np.ndarray):
non_tensor_batch[key] = np.concatenate(value, axis=0)
else:
non_tensor_batch[key] = []
for item in value:
non_tensor_batch[key].extend(item)
non_tensor_batch[key] = np.concatenate(value, axis=0)
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
......@@ -574,16 +564,10 @@ class DataProto:
repeated_non_tensor_batch = {}
for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray):
if interleave:
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
else:
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
if interleave:
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
else:
if interleave:
repeated_non_tensor_batch[key] = [item for item in value for _ in range(repeat_times)]
else:
repeated_non_tensor_batch[key] = [item for _ in range(repeat_times) for item in value]
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
return DataProto(
batch=repeated_batch,
......@@ -591,39 +575,6 @@ class DataProto:
meta_info=self.meta_info,
)
def broadcast(self, src: int, group: Optional[ProcessGroup] = None):
for key in self.batch.sorted_keys:
dist.broadcast(self.batch[key], src=src, group=group, async_op=False)
object_list = [self.non_tensor_batch]
dist.broadcast_object_list(object_list, src=src, group=group)
self.non_tensor_batch = object_list[0]
def all_gather(self, group: Optional[ProcessGroup] = None):
world_size = dist.get_world_size(group)
output = {}
for key in self.batch.sorted_keys:
value = self.batch[key].contiguous()
output[key] = [torch.empty_like(value) for _ in range(world_size)]
dist.all_gather(output[key], value, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=0)
self.batch = TensorDict(output, batch_size=self.batch.batch_size[0] * world_size)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(world_size)]
dist.all_gather_object(all_non_tensor_batch, self.non_tensor_batch, group=group)
non_tensor_batch = defaultdict(list)
for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray):
non_tensor_batch[key] = np.concatenate([batch[key] for batch in all_non_tensor_batch])
else:
for batch in all_non_tensor_batch:
non_tensor_batch[key].extend(batch[key])
self.non_tensor_batch = non_tensor_batch
self.check_consistency()
@dataclass
class DataProtoFuture:
......@@ -664,10 +615,53 @@ class DataProtoFuture:
return arg_future_lst
def get(self):
output = ray.get(self.futures) # dp_size.
for o in output:
assert isinstance(o, DataProto)
output = self.collect_fn(output) # select dp, concat
outputs = ray.get(self.futures) # dp_size.
for output in outputs:
assert isinstance(output, DataProto)
outputs = self.collect_fn(outputs) # select dp, concat
if self.dispatch_fn is not None:
output = self.dispatch_fn(output) # split in batch dim, select using dp
return output
outputs = self.dispatch_fn(outputs) # split in batch dim, select using dp
return outputs
def allgather_dict_tensors(
tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
) -> Union[Dict[str, torch.Tensor], TensorDict]:
"""
TODO: optimize this.
- We can use async ops
- We can use only one allgather
"""
if isinstance(tensors, TensorDict):
is_tensor_dict = True
tensors_as_dict = tensors.to_dict()
else:
tensors_as_dict = tensors
is_tensor_dict = False
output = {}
sorted_keys = sorted(tensors_as_dict.keys())
for key in sorted_keys:
val = tensors_as_dict[key]
output[key] = [torch.empty_like(val) for _ in range(size)]
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=dim)
if is_tensor_dict:
output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)
return output
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
# Note that this is an in-place operation, just like torch.distributed.all_gather
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(size)]
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
......@@ -14,3 +14,6 @@
from .worker import Worker
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
......@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from enum import Enum, auto
from functools import wraps
from types import FunctionType
from typing import Dict, List
from typing import TYPE_CHECKING, Dict, List, Literal, Union
import ray
from verl.protocol import DataProto, DataProtoFuture
from ...protocol import DataProto, DataProtoFuture
if TYPE_CHECKING:
from .worker_group import WorkerGroup
# here we add a magic number to avoid clashing with user-defined functions that already have this attribute
......@@ -27,13 +31,13 @@ MAGIC_ATTR = "attrs_3141562937"
class Dispatch(Enum):
RANK_ZERO = 0
ONE_TO_ALL = 1
ALL_TO_ALL = 2
DP_COMPUTE = 3
DP_COMPUTE_PROTO = 4
DP_COMPUTE_PROTO_WITH_FUNC = 5
DP_COMPUTE_METRIC = 6
RANK_ZERO = auto()
ONE_TO_ALL = auto()
ALL_TO_ALL = auto()
DP_COMPUTE = auto()
DP_COMPUTE_PROTO = auto()
DP_COMPUTE_PROTO_WITH_FUNC = auto()
DP_COMPUTE_METRIC = auto()
class Execute(Enum):
......@@ -41,100 +45,85 @@ class Execute(Enum):
RANK_ZERO = 1
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
splitted_kwargs[key] = val.chunk(chunks=chunks)
for key, value in kwargs.items():
assert isinstance(value, (DataProto, DataProtoFuture))
splitted_kwargs[key] = value.chunk(chunks=chunks)
return splitted_args, splitted_kwargs
def dispatch_one_to_all(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
args = tuple([arg] * worker_group.world_size for arg in args)
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
return args, kwargs
def dispatch_all_to_all(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
return args, kwargs
def collect_all_to_all(worker_group, output):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def collect_all_to_all(worker_group: "WorkerGroup", output):
return output
def _concat_data_proto_or_future(output: List):
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
# make sure all the elements in outputs have the same type
for o in output:
assert type(o) is type(output[0])
for output in outputs:
assert type(output) is type(outputs[0])
o = output[0]
output = outputs[0]
if isinstance(o, DataProto):
return DataProto.concat(output)
elif isinstance(o, ray.ObjectRef):
return DataProtoFuture.concat(output)
if isinstance(output, DataProto):
return DataProto.concat(outputs)
elif isinstance(output, ray.ObjectRef):
return DataProtoFuture.concat(outputs)
else:
raise NotImplementedError
def dispatch_dp_compute(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
return args, kwargs
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
for arg in args:
assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
for value in kwargs.values():
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
def collect_dp_compute(worker_group, output):
from verl.single_controller.base.worker_group import WorkerGroup
return args, kwargs
assert isinstance(worker_group, WorkerGroup)
assert len(output) == worker_group.world_size
return output
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
assert len(outputs) == worker_group.world_size
return outputs
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
assert type(args[0]) is FunctionType # NOTE: The first arg is a function!
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
return splitted_args_with_func, splitted_kwargs
def collect_dp_compute_data_proto(worker_group, output):
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
for output in outputs:
assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"
output = collect_dp_compute(worker_group, output)
return _concat_data_proto_or_future(output)
outputs = collect_dp_compute(worker_group, outputs)
return _concat_data_proto_or_future(outputs)
def get_predefined_dispatch_fn(dispatch_mode):
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: {
"dispatch_fn": dispatch_one_to_all,
......@@ -144,6 +133,10 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_all_to_all,
"collect_fn": collect_all_to_all,
},
Dispatch.DP_COMPUTE: {
"dispatch_fn": dispatch_dp_compute,
"collect_fn": collect_dp_compute,
},
Dispatch.DP_COMPUTE_PROTO: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute_data_proto,
......@@ -152,11 +145,15 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_dp_compute_data_proto_with_func,
"collect_fn": collect_dp_compute_data_proto,
},
Dispatch.DP_COMPUTE_METRIC: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute,
},
}
return predefined_dispatch_mode_fn[dispatch_mode]
def get_predefined_execute_fn(execute_mode):
def get_predefined_execute_fn(execute_mode: Execute):
"""
Note that here we only ask execute_all and execute_rank_zero to be implemented.
The choice of how these two functions handle the 'blocking' argument is left to users.
......@@ -168,17 +165,17 @@ def get_predefined_execute_fn(execute_mode):
return predefined_execute_mode_fn[execute_mode]
def _check_dispatch_mode(dispatch_mode):
assert isinstance(dispatch_mode, (Dispatch, Dict)), (
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
assert isinstance(dispatch_mode, (Dispatch, dict)), (
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
)
if isinstance(dispatch_mode, Dict):
if isinstance(dispatch_mode, dict):
necessary_keys = ["dispatch_fn", "collect_fn"]
for key in necessary_keys:
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
def _check_execute_mode(execute_mode):
def _check_execute_mode(execute_mode: Execute):
assert isinstance(execute_mode, Execute), f"execute_mode must be an Execute. Got {execute_mode}"
......@@ -189,9 +186,10 @@ def _materialize_futures(*args, **kwargs):
arg = arg.get()
# add more type to materialize
new_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, DataProtoFuture):
kwargs[k] = v.get()
for key, value in kwargs.items():
if isinstance(value, DataProtoFuture):
kwargs[key] = value.get()
new_args = tuple(new_args)
return new_args, kwargs
......
......@@ -18,11 +18,13 @@ the class for Worker
import os
import socket
from dataclasses import dataclass
from typing import Tuple
import ray
import torch
from verl.single_controller.base.decorator import Dispatch, Execute, register
from verl.single_controller.base.register_center.ray import create_worker_group_register_center
from .decorator import Dispatch, Execute, register
from .register_center.ray import create_worker_group_register_center
@dataclass
......@@ -40,7 +42,7 @@ class DistGlobalInfo:
class WorkerHelper:
def _get_node_ip(self):
def _get_node_ip(self) -> str:
host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6
......@@ -49,12 +51,12 @@ class WorkerHelper:
host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip
def _get_free_port(self):
def _get_free_port(self) -> int:
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
def get_availale_master_addr_port(self):
def get_availale_master_addr_port(self) -> Tuple[str, str]:
return self._get_node_ip(), str(self._get_free_port())
def _get_pid(self):
......@@ -81,16 +83,26 @@ class WorkerMeta:
# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
"""A (distributed) worker."""
_world_size: int
_rank: int
_local_world_size: int
_local_rank: int
_master_addr: str
_master_port: str
_cuda_visible_devices: str
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
# note that here we use int to distinguish
disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0))
disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
if disable_worker_init:
return instance
rank = os.environ.get("RANK", None)
worker_group_prefix = os.environ.get("WG_PREFIX", None)
rank = os.getenv("RANK", None)
worker_group_prefix = os.getenv("WG_PREFIX", None)
# when the @ray.remote decorator is applied, __new__ is called, but we don't want to run _configure_before_init
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
......@@ -112,13 +124,19 @@ class Worker(WorkerHelper):
def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from environment variables. Note that the import must be inside the class because it is executed remotely
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.getenv("RANK"))
self._rank = rank
self._world_size = world_size
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
if "AMD" in torch.cuda.get_device_name():
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
torch.cuda.set_device(int(cuda_visible_devices))
master_addr = os.getenv("MASTER_ADDR")
master_port = os.getenv("MASTER_PORT")
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
......@@ -149,6 +167,7 @@ class Worker(WorkerHelper):
if val is not None:
# print(f"set {key} to {val}")
os.environ[key] = str(val)
os.environ["REDIS_STORE_SERVER_HOST"] = (
str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
)
......@@ -157,7 +176,7 @@ class Worker(WorkerHelper):
return self._master_addr, self._master_port
def get_cuda_visible_devices(self):
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices
def print_rank0(self, *args, **kwargs):
......
......@@ -19,20 +19,20 @@ import logging
import signal
import threading
import time
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional
from verl.single_controller.base.decorator import (
MAGIC_ATTR,
Dispatch,
get_predefined_dispatch_fn,
get_predefined_execute_fn,
)
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
"""The resource pool with meta info such as world size."""
def __init__(
self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
) -> None:
if process_on_nodes is None:
process_on_nodes = []
self._store = process_on_nodes
self.max_collocate_count = max_collocate_count
self.n_gpus_per_node = n_gpus_per_node # this is left for future Huawei GPUs that contain 16 GPUs per node
......@@ -73,28 +73,23 @@ class ClassWithInitArgs:
self.args = args
self.kwargs = kwargs
# def add_arg(self, arg):
# self.args += (arg,)
# def add_kwarg(self, key, value):
# self.kwargs[key] = value
def __call__(self) -> Any:
return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
import time
while True:
for worker in workers:
if not is_alive(worker):
logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
signal.raise_signal(signal.SIGABRT)
time.sleep(gap_time)
class WorkerGroup:
"""A group of workers"""
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False
......@@ -136,14 +131,10 @@ class WorkerGroup:
def world_size(self):
return len(self._workers)
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup,
# MegatronWorkerGroup, XperfWorkerGroup should skip
def _bind_worker_method(self, user_defined_cls, func_generator):
"""
Bind the worker method to the WorkerGroup
"""
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
......
......@@ -13,27 +13,28 @@
# limitations under the License.
import os
import random
import re
import string
import time
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
import ray
from ray.actor import ActorHandle
from ray.experimental.state.api import get_actor
from ray.util import list_named_actors
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
from verl.single_controller.base.decorator import MAGIC_ATTR
from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
from ..base.decorator import MAGIC_ATTR
__all__ = ["Worker"]
def get_random_string(length: int) -> str:
import random
import string
letters_digits = string.ascii_letters + string.digits
return "".join(random.choice(letters_digits) for _ in range(length))
......@@ -50,6 +51,27 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block
return func
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
"""
Sort the placement groups by node IP; all bundles in a single placement group should be on the same node.
FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK
to be consistent across nodes when resuming from a checkpoint.
With this function, if there is only one resource pool and no node change, RANK stays consistent
across nodes over multiple Ray jobs, even if the whole Ray cluster is restarted.
"""
node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()}
pg_ip = {}
for pg in pgs:
specs = ray._private.state.state.placement_group_table(pg.id)
# all bundles should be on the same node
node_id = specs["bundles_to_node_id"][0]
pg_ip[pg.id] = node_ip[node_id]
return sorted(pgs, key=lambda pg: pg_ip[pg.id])
class RayResourcePool(ResourcePool):
def __init__(
self,
......@@ -57,7 +79,7 @@ class RayResourcePool(ResourcePool):
use_gpu: bool = True,
name_prefix: str = "",
max_colocate_count: int = 5,
detached=False,
detached: bool = False,
) -> None:
super().__init__(process_on_nodes, max_colocate_count)
self.use_gpu = use_gpu
......@@ -66,7 +88,7 @@ class RayResourcePool(ResourcePool):
self.pgs = None
self.detached = detached
def get_placement_groups(self, strategy="STRICT_PACK", name=None):
def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]:
if self.pgs is not None:
return self.pgs
......@@ -97,7 +119,7 @@ class RayResourcePool(ResourcePool):
def extract_pg_from_exist(
resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List:
) -> List[PlacementGroup]:
src_pgs = [
pg
for role_name, resource_pool in resource_pools.items()
......@@ -151,7 +173,12 @@ class RayClassWithInitArgs(ClassWithInitArgs):
self._options.update(options)
def __call__(
self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None
self,
placement_group: PlacementGroup,
placement_group_bundle_idx: int,
use_gpu: bool = True,
num_gpus: int = 1,
sharing_with: Worker = None,
) -> Any:
if sharing_with is not None:
target_node_id = ray.get(sharing_with.get_node_id.remote())
......@@ -188,8 +215,8 @@ class RayWorkerGroup(WorkerGroup):
ray_cls_with_init: RayClassWithInitArgs = None,
bin_pack: bool = True,
name_prefix: str = None,
detached=False,
worker_names=None,
detached: bool = False,
worker_names: List[str] = None,
**kwargs,
) -> None:
super().__init__(resource_pool=resource_pool, **kwargs)
......@@ -210,21 +237,24 @@ class RayWorkerGroup(WorkerGroup):
if ray_cls_with_init is not None:
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
def _is_worker_alive(self, worker: ray.actor.ActorHandle):
def _is_worker_alive(self, worker: ActorHandle) -> bool:
worker_state_dict = get_actor(worker._actor_id.hex())
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
def _init_with_detached_workers(self, worker_names):
def _init_with_detached_workers(self, worker_names: List[str]) -> None:
workers = [ray.get_actor(name=name) for name in worker_names]
self._workers = workers
self._world_size = len(worker_names)
def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
def _init_with_resource_pool(
self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool
):
use_gpu = resource_pool.use_gpu
strategy = "PACK"
if bin_pack:
strategy = "STRICT_PACK"
pgs = resource_pool.get_placement_groups(strategy=strategy)
world_size = resource_pool.world_size
self._world_size = world_size
......@@ -232,8 +262,8 @@ class RayWorkerGroup(WorkerGroup):
num_gpus = 1 / resource_pool.max_collocate_count
rank = -1
for pg_idx, local_world_size in enumerate(resource_pool.store):
pg = pgs[pg_idx]
local_world_size = resource_pool.store[0]
for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the "
for local_rank in range(local_world_size):
rank += 1
......@@ -251,8 +281,6 @@ class RayWorkerGroup(WorkerGroup):
env_vars["MASTER_ADDR"] = self._master_addr
env_vars["MASTER_PORT"] = self._master_port
import re
cia_name = type(ray_cls_with_init.cls).__name__
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
......@@ -272,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
if rank == 0:
register_center_actor = None
for _ in range(120):
for _ in range(360):
if f"{self.name_prefix}_register_center" not in list_named_actors():
time.sleep(1)
else:
......
......@@ -19,7 +19,7 @@ import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Optional, Tuple
from verl.workers.config import WorkerConfig
from ..workers.config import WorkerConfig
def recursive_post_init(dataclass_obj):
......@@ -36,12 +36,13 @@ class DataConfig:
train_files: str = ""
val_files: str = ""
prompt_key: str = "prompt"
answer_key: str = "answer"
image_key: str = "images"
max_prompt_length: int = 512
max_response_length: int = 512
rollout_batch_size: int = 512
return_raw_input_ids: bool = False
return_raw_prompt: bool = False
system_prompt: str = r"Please reason step by step, and put your final answer within \boxed{}."
val_batch_size: int = -1
system_prompt: Optional[str] = None
shuffle: bool = True
seed: int = 1
max_pixels: int = 4194304
......@@ -52,10 +53,12 @@ class DataConfig:
class AlgorithmConfig:
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
adv_estimator: str = "grpo"
disable_kl: bool = False
use_kl_loss: bool = False
kl_penalty: str = "kl"
kl_type: str = "fixed"
kl_coef: float = 1e-3
kl_type: str = "fixed"
kl_horizon: float = 0.0
kl_target: float = 0.0
......@@ -67,18 +70,17 @@ class TrainerConfig:
project_name: str = "easy_r1"
experiment_name: str = "demo"
logger: Tuple[str] = ("console", "wandb")
val_generations_to_log_to_wandb: int = 0
nnodes: int = 1
n_gpus_per_node: int = 8
save_freq: int = -1
load_checkpoint_path: Optional[str] = None
critic_warmup: int = 0
val_freq: int = -1
val_before_train: bool = True
val_only: bool = False
test_freq: int = -1
critic_warmup: int = 0
remove_previous_ckpt: bool = False
del_local_ckpt_after_load: bool = False
val_generations_to_log: int = 0
save_freq: int = -1
save_limit: int = -1
save_checkpoint_path: Optional[str] = None
load_checkpoint_path: Optional[str] = None
def post_init(self):
if self.save_checkpoint_path is None:
......@@ -95,6 +97,10 @@ class PPOConfig:
def post_init(self):
self.worker.rollout.prompt_length = self.data.max_prompt_length
self.worker.rollout.response_length = self.data.max_response_length
self.worker.actor.disable_kl = self.algorithm.disable_kl
self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
self.worker.actor.kl_penalty = self.algorithm.kl_penalty
self.worker.actor.kl_coef = self.algorithm.kl_coef
def deep_post_init(self):
recursive_post_init(self)
......
# Copyright 2022 The HuggingFace Team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,50 +18,55 @@ The function implemented in this file should be used by trainer with different d
implement PPO
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import verl.utils.torch_functional as verl_F
from ..utils import torch_functional as VF
if TYPE_CHECKING:
from verl.trainer.config import AlgorithmConfig
from .config import AlgorithmConfig
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
class KLController(ABC):
@abstractmethod
def update(self, current_kl: float, n_steps: int) -> None: ...
class AdaptiveKLController(KLController):
"""Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf"""
def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
def update(self, current_kl: float, n_steps: int) -> None:
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
class FixedKLController(KLController):
"""Fixed KL controller."""
def __init__(self, kl_coef: float):
self.value = kl_coef
def __init__(self, init_kl_coef: float):
self.value = init_kl_coef
def update(self, current_kl, n_steps):
def update(self, current_kl: float, n_steps: int) -> None:
pass
def get_kl_controller(algorithm_config: "AlgorithmConfig"):
def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
"""Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
if algorithm_config.kl_type == "fixed":
kl_ctrl = FixedKLController(kl_coef=algorithm_config.kl_coef)
kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
elif algorithm_config.kl_type == "adaptive":
assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
kl_ctrl = AdaptiveKLController(
......@@ -70,19 +75,20 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig"):
horizon=algorithm_config.kl_horizon,
)
else:
raise ValueError("Unknown kl_ctrl type")
raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")
return kl_ctrl
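# Worked example (hypothetical numbers, not part of the original change) for the adaptive
# update above: with target_kl = 0.1 and horizon = 10000, an observed KL of 0.2 over 256
# steps saturates the proportional error at +0.2, so the coefficient grows by a factor of
# 1 + 0.2 * 256 / 10000.
def _demo_adaptive_kl_update() -> None:
    kl_ctrl = AdaptiveKLController(init_kl_coef=0.05, target_kl=0.1, horizon=10000)
    kl_ctrl.update(current_kl=0.2, n_steps=256)
    assert abs(kl_ctrl.value - 0.05 * (1 + 0.2 * 256 / 10000)) < 1e-8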
@torch.no_grad()
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513
Args:
token_level_rewards: `(torch.Tensor)`
......@@ -103,27 +109,26 @@ def compute_gae_advantage_return(
shape: (bs, response_length)
"""
with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = (advantages + values) * eos_mask
advantages = VF.masked_whiten(advantages, eos_mask) * eos_mask
return advantages, returns
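# Tiny worked example (hypothetical values, not part of the original change): with
# gamma = lam = 1.0 the GAE recursion above reduces to "sum of future rewards minus V_t",
# so returns = advantages + values before the advantages are whitened.
def _demo_gae_advantage_return() -> None:
    rewards = torch.tensor([[0.0, 0.0, 1.0]])
    values = torch.tensor([[0.5, 0.5, 0.5]])
    mask = torch.ones_like(rewards)
    _, returns = compute_gae_advantage_return(rewards, values, mask, gamma=1.0, lam=1.0)
    assert torch.allclose(returns, torch.tensor([[1.0, 1.0, 1.0]]))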
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
):
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
......@@ -133,6 +138,50 @@ def compute_grpo_outcome_advantage(
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean, id2std = {}, {}
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
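# Hypothetical example (not part of the original change) for the group normalization above:
# two rollouts of the same prompt with outcome rewards 1 and 0 have group mean 0.5 and std
# ~0.707, so their advantages are roughly +0.707 and -0.707, broadcast over response tokens.
def _demo_grpo_advantage() -> None:
    rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0]])  # outcome reward on the last token
    mask = torch.ones_like(rewards)
    index = np.array(["uid-0", "uid-0"])
    advantages, _ = compute_grpo_outcome_advantage(rewards, mask, index)
    assert torch.allclose(advantages[:, 0], torch.tensor([0.7071, -0.7071]), atol=1e-3)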
@torch.no_grad()
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
......@@ -144,31 +193,33 @@ def compute_grpo_outcome_advantage(
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}.")
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
response_num - 1
)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
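# Hypothetical example (not part of the original change) for the leave-one-out baseline
# above: with two responses per prompt, scaling by n / (n - 1) makes each advantage equal to
# its own reward minus the mean of the other responses, e.g. rewards (1, 0) give (+1, -1).
def _demo_rloo_advantage() -> None:
    rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0]])
    mask = torch.ones_like(rewards)
    index = np.array(["uid-0", "uid-0"])
    advantages, _ = compute_rloo_outcome_advantage(rewards, mask, index)
    assert torch.allclose(advantages[:, 0], torch.tensor([1.0, -1.0]))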
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262
......@@ -184,26 +235,24 @@ def compute_reinforce_plus_plus_outcome_advantage(
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
advantages = VF.masked_whiten(returns, eos_mask)
advantages *= eos_mask
returns *= eos_mask
return advantages, returns
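# --- Illustrative sketch (editor's addition): the backward discounted-return
# recursion above on a single toy sequence with gamma = 0.9 and a terminal reward
# of 1.0 (whitening and the EOS reset are omitted here).
import torch
toy_rewards = torch.tensor([[0.0, 0.0, 1.0]])
toy_gamma = 0.9
toy_returns = torch.zeros_like(toy_rewards)
running = torch.zeros(1)
for t in reversed(range(toy_rewards.shape[1])):
    running = toy_rewards[:, t] + toy_gamma * running
    toy_returns[:, t] = running
print(toy_returns)  # tensor([[0.8100, 0.9000, 1.0000]])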
@torch.no_grad()
def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for ReMax, operating only on the outcome reward.
This implementation is based on the paper: https://arxiv.org/abs/2310.10505
......@@ -225,23 +274,31 @@ def compute_remax_outcome_advantage(
"""
response_length = token_level_rewards.shape[-1]
# scores = token_level_rewards.sum(dim=-1)
with torch.no_grad():
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * eos_mask
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
return advantages, returns
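# --- Illustrative sketch (editor's addition): the flip/cumsum/flip trick above
# computes the reward-to-go per token, and the greedy-rollout reward acts as a
# per-sequence baseline. The values below are made up.
import torch
toy_rewards = torch.tensor([[0.0, 0.0, 1.0]])
toy_mask = torch.ones_like(toy_rewards)
toy_returns = (toy_rewards * toy_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * toy_mask
toy_baseline = torch.tensor([0.4])  # reward obtained by the greedy rollout
toy_adv = toy_returns - toy_baseline.unsqueeze(-1) * toy_mask
print(toy_returns)  # [[1., 1., 1.]]
print(toy_adv)      # [[0.6, 0.6, 0.6]]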
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
kl = old_log_prob - ref_log_prob
def compute_rewards(
token_level_scores: torch.Tensor,
log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
kl_ratio: float,
) -> torch.Tensor:
kl = log_probs - ref_log_probs
return token_level_scores - kl * kl_ratio
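# --- Illustrative sketch (editor's addition): the reward shaping above subtracts a
# scaled per-token KL term (log_probs - ref_log_probs) from the raw scores; the
# values below are made up.
import torch
toy_scores = torch.tensor([[0.0, 1.0]])
toy_log_probs = torch.tensor([[-1.0, -1.0]])
toy_ref_log_probs = torch.tensor([[-1.5, -0.5]])
toy_kl_ratio = 0.1
toy_shaped = toy_scores - (toy_log_probs - toy_ref_log_probs) * toy_kl_ratio
print(toy_shaped)  # [[-0.05, 1.05]]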
def compute_policy_loss(
old_log_prob, log_prob, advantages, eos_mask, cliprange
old_log_probs: torch.Tensor,
log_probs: torch.Tensor,
advantages: torch.Tensor,
eos_mask: torch.Tensor,
cliprange: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
"""Compute the policy loss.
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
Args:
old_log_prob: `(torch.Tensor)`
......@@ -260,95 +317,88 @@ def compute_policy_loss(
policy gradient loss computed via PPO
pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped
"""
negative_approx_kl = log_prob - old_log_prob
negative_approx_kl = log_probs - old_log_probs
# clamp the ratio before exp to avoid nan
# see: https://github.com/pytorch/pytorch/issues/10729
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
clipped_ratio = torch.exp(torch.clamp(negative_approx_kl, np.log(1.0 - cliprange), np.log(1.0 + cliprange)))
ppo_kl = VF.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_losses2 = -advantages * clipped_ratio
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
pg_loss = VF.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = VF.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl
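# --- Illustrative sketch (editor's addition): the clipped surrogate above on two toy
# tokens with positive advantage (the log-ratio clamp before exp is omitted for brevity).
import torch
toy_old = torch.tensor([[-1.0, -1.0]])
toy_new = torch.tensor([[-0.5, -2.0]])
toy_adv = torch.tensor([[1.0, 1.0]])
toy_clip = 0.2
toy_ratio = torch.exp(toy_new - toy_old)                          # [1.65, 0.37]
toy_clipped = torch.clamp(toy_ratio, 1 - toy_clip, 1 + toy_clip)  # [1.2, 0.8]
toy_loss = torch.max(-toy_adv * toy_ratio, -toy_adv * toy_clipped)
# token 1: the clipped term (-1.2) wins over -1.65, so the update is capped;
# token 2: the unclipped term (-0.37) wins, so no clipping is applied.
print(toy_loss.mean())  # mean of [-1.2, -0.3679] ~ -0.78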
def compute_entropy_loss(logits, eos_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss
def compute_value_loss(
vpreds: torch.Tensor,
returns: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
cliprange_value: float,
) -> Tuple[torch.Tensor, float]:
"""Compute the value loss.
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Copied from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556
Args:
vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange_value: (float)
The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347
Returns:
vf_loss: a scalar (`torch.FloatTensor`):
value function loss
vf_clipfrac: a float
The ratio of vf being clipped
"""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = torch.square(vpreds - returns)
vf_losses2 = torch.square(vpredclipped - returns)
vf_loss = 0.5 * VF.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = VF.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac
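# --- Illustrative sketch (editor's addition): the value clipping above keeps the new
# prediction within cliprange_value of the old value and takes the worse of the two
# squared errors.
import torch
toy_vpreds = torch.tensor([[1.5]])
toy_values = torch.tensor([[1.0]])    # old value estimate
toy_returns = torch.tensor([[2.0]])
toy_cliprange = 0.2
toy_clipped = torch.clamp(toy_vpreds, toy_values - toy_cliprange, toy_values + toy_cliprange)  # 1.2
toy_loss1 = torch.square(toy_vpreds - toy_returns)   # 0.25
toy_loss2 = torch.square(toy_clipped - toy_returns)  # 0.64
toy_vf_loss = 0.5 * torch.max(toy_loss1, toy_loss2).mean()  # 0.32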
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.Tensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
def kl_penalty(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
"""Compute KL divergence given log_probs and ref_log_probs.
Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Args:
logprob:
ref_logprob:
log_probs: torch.Tensor
ref_log_probs: torch.Tensor
Returns:
kl_div: torch.Tensor
"""
log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
if kl_penalty == "kl":
return logprob - ref_logprob
return log_probs - ref_log_probs
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
return (log_probs - ref_log_probs).abs()
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
return 0.5 * (log_probs - ref_log_probs).square()
# J. Schulman. Approximating kl divergence, 2020.
# # URL http://joschu.net/blog/kl-approx.html.
# URL http://joschu.net/blog/kl-approx.html
if kl_penalty == "low_var_kl":
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
kl = ref_log_probs - log_probs
kld = (kl.exp() - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)
raise NotImplementedError
raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
......@@ -16,81 +16,100 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o
"""
import json
import torch
import ray
from omegaconf import OmegaConf
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.config import PPOConfig
from verl.trainer.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
from verl.utils import get_processor, get_tokenizer
from verl.workers.fsdp_workers import FSDPWorker
from verl.workers.reward import CustomRewardManager
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import CustomRewardManager
from .config import PPOConfig
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
# please make sure main_task is not scheduled on head
@ray.remote(num_cpus=1)
class Runner:
"""A runner for RL training."""
def run(self, config: PPOConfig):
# print config
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer
tokenizer = get_tokenizer(
config.worker.actor.model.model_path,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
processor = get_processor(
config.worker.actor.model.model_path,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
# define worker classes
ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.ActorRollout: ray.remote(FSDPWorker),
Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
def main():
cli_args = OmegaConf.from_cli()
file_config = OmegaConf.load(cli_args.config)
del cli_args.config
default_config = OmegaConf.structured(PPOConfig())
ppo_config = OmegaConf.merge(default_config, file_config, cli_args)
if hasattr(cli_args, "config"):
config_path = cli_args.pop("config", None)
file_config = OmegaConf.load(config_path)
default_config = OmegaConf.merge(default_config, file_config)
ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
# this is for local ray cluster
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(ppo_config))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config: PPOConfig):
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer
tokenizer = get_tokenizer(config.worker.actor.model.model_path)
processor = get_processor(config.worker.actor.model.model_path, use_fast=True)
# define worker classes
ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.ActorRollout: ray.remote(FSDPWorker),
Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
val_reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
# for rocm
if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ignore_reinit_error=True,
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
else:
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
if __name__ == "__main__":
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
import numpy as np
import torch
from ..protocol import DataProto
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
return {key: np.mean(value) for key, value in metrics.items()}
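# --- Illustrative usage (editor's addition): each metric key carries a list of
# per-micro-batch values and reduce_metrics collapses it to the mean, e.g.
# reduce_metrics({"actor/pg_loss": [0.1, 0.3]}) -> {"actor/pg_loss": 0.2}.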
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].size(-1)
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float()
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
num_response_tokens = torch.sum(batch.batch["response_mask"]).item()
num_overall_tokens = sum(batch.meta_info["global_token_num"])
num_tokens_of_section = {
**dict.fromkeys(["gen", "reward"], num_response_tokens),
**dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens),
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
total_num_tokens = sum(batch.meta_info["global_token_num"])
time = timing_raw["step"]
return {
"perf/total_num_tokens": total_num_tokens,
"perf/time_per_step": time,
"perf/throughput": total_num_tokens / (time * n_gpus),
}
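# --- Illustrative arithmetic (editor's addition, made-up numbers): with 1,200,000
# tokens processed in a 30 s step on 8 GPUs, the metrics above are
# perf/time_per_step = 30 and perf/throughput = 1_200_000 / (30 * 8) = 5000
# tokens per second per GPU.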
......@@ -18,54 +18,71 @@ This trainer supports model-agnostic model initialization with huggingface
import os
import uuid
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Dict, Optional, Type
from enum import Enum, IntEnum, auto
from typing import Any, Callable, Dict, List, Optional, Type
import numpy as np
import ray
import torch
from codetiming import Timer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from ray.experimental.tqdm_ray import tqdm
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer import core_algos
from verl.trainer.config import PPOConfig
from verl.utils.rl_dataset import RLHFDataset, collate_fn
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import Tracking
from verl.workers.fsdp_workers import FSDPWorker
from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto
from ..single_controller.base import Worker
from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.dataset import RLHFDataset, collate_fn
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
WorkerType = Type[Worker]
class Role(Enum):
class Role(IntEnum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""
Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6
Actor = auto()
Rollout = auto()
ActorRollout = auto()
Critic = auto()
RefPolicy = auto()
RewardModel = auto()
ActorRolloutRef = auto()
class AdvantageEstimator(str, Enum):
"""
Using an enumeration class to avoid spelling errors in adv_estimator
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REMAX = "remax"
RLOO = "rloo"
@dataclass
class ResourcePoolManager:
"""
Define a resource pool specification and the mapping from roles to resource pools.
Resource pools will be initialized first.
"""
resource_pool_spec: dict[str, list[int]]
......@@ -82,23 +99,41 @@ class ResourcePoolManager:
)
self.resource_pool_dict[resource_pool_name] = resource_pool
self._check_resource_available()
def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker_cls"""
"""Get the resource pool of the worker."""
return self.resource_pool_dict[self.mapping[role]]
def get_n_gpus(self) -> int:
"""Get the number of gpus in this cluster."""
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
responses = data.batch["responses"]
response_length = responses.size(1)
def _check_resource_available(self):
"""Check if the resource pool can be satisfied in this ray cluster."""
node_available_resources = ray.state.available_resources_per_node()
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
# check total required gpus can be satisfied
total_available_gpus = sum(node_available_gpus.values())
total_required_gpus = sum(
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
)
if total_available_gpus < total_required_gpus:
raise ValueError(
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
)
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
token_level_scores = data.batch["token_level_scores"]
batch_size = data.batch.batch_size[0]
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
response_mask = data.batch["response_mask"]
# compute kl between ref_policy and current policy
if "ref_log_prob" in data.batch.keys():
if "ref_log_probs" in data.batch.keys():
kld = core_algos.kl_penalty(
data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty
) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
......@@ -108,191 +143,49 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
token_level_rewards = token_level_scores - beta * kld
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
# According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards
metrics = {"critic/kl": current_kl, "critic/kl_coeff": beta}
metrics = {"critic/kl": current_kl, "critic/kl_coef": beta}
return data, metrics
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
# prepare response group
# TODO: add other ways to estimate advantages
if adv_estimator == "gae":
def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
token_level_rewards = data.batch["token_level_rewards"]
response_mask = data.batch["response_mask"]
index = data.non_tensor_batch["uid"]
if adv_estimator == AdvantageEstimator.GAE:
values = data.batch["values"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
token_level_rewards = data.batch["token_level_rewards"]
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=token_level_rewards, values=values, eos_mask=response_mask, gamma=gamma, lam=lam
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == "grpo":
token_level_rewards = data.batch["token_level_rewards"]
index = data.non_tensor_batch["uid"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
elif adv_estimator == AdvantageEstimator.GRPO:
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == "reinforce_plus_plus":
token_level_rewards = data.batch["token_level_rewards"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == "remax":
token_level_rewards = data.batch["token_level_rewards"]
index = data.non_tensor_batch["uid"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
elif adv_estimator == AdvantageEstimator.REMAX:
reward_baselines = data.batch["reward_baselines"]
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=token_level_rewards, reward_baselines=reward_baselines, eos_mask=response_mask
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == AdvantageEstimator.RLOO:
advantages, returns = core_algos.compute_rloo_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
)
else:
raise NotImplementedError
return data
def reduce_metrics(metrics: Dict[str, Any]):
for key, val in metrics.items():
metrics[key] = np.mean(val)
return metrics
def _compute_response_info(batch: DataProto):
response_length = batch.batch["responses"].shape[-1]
prompt_mask = batch.batch["attention_mask"][:, :-response_length]
response_mask = batch.batch["attention_mask"][:, -response_length:]
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)
return dict(
response_mask=response_mask,
prompt_length=prompt_length,
response_length=response_length,
)
def compute_data_metrics(batch: DataProto, use_critic: bool = True):
# TODO: add response length
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].shape[-1]
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
num_response_tokens = torch.sum(response_info["response_length"]).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
"gen": num_response_tokens,
**{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]},
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
data.batch["advantages"] = advantages
data.batch["returns"] = returns
return data
@contextmanager
......@@ -308,8 +201,6 @@ class RayPPOTrainer:
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(
self,
config: PPOConfig,
......@@ -318,8 +209,8 @@ class RayPPOTrainer:
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None,
reward_fn: Callable = None,
val_reward_fn: Callable = None,
):
self.tokenizer = tokenizer
self.processor = processor
......@@ -328,42 +219,51 @@ class RayPPOTrainer:
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.worker.hybrid_engine
assert self.hybrid_engine, "Currently, only support hybrid engine"
if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()}"
assert Role.ActorRollout in role_worker_mapping, (
f"ActorRollout should be included in {role_worker_mapping.keys()}."
)
else:
raise NotImplementedError
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_reward_model = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
# define KL control
if self.use_reference_policy:
if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
self.use_reference_policy = True
self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
else:
self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0)
self.use_reference_policy = False
self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
print("KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics.")
if self.config.algorithm.adv_estimator == "gae":
if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
elif self.config.algorithm.adv_estimator == "grpo":
self.use_critic = False
elif self.config.algorithm.adv_estimator == "reinforce_plus_plus":
self.use_critic = False
elif self.config.algorithm.adv_estimator == "remax":
self.use_critic = False
else:
raise NotImplementedError
self.use_critic = False
if config.algorithm.adv_estimator not in list(AdvantageEstimator):
raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.")
if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")
if self.use_critic and config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")
self._create_dataloader()
def _create_dataloader(self):
def _create_dataloader(self) -> None:
self.train_dataset = RLHFDataset(
data_path=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
answer_key=self.config.data.answer_key,
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
system_prompt=self.config.data.system_prompt,
......@@ -378,13 +278,14 @@ class RayPPOTrainer:
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = DataLoader(
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.rollout_batch_size,
sampler=sampler,
num_workers=8,
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
pin_memory=False,
drop_last=True,
)
self.val_dataset = RLHFDataset(
......@@ -392,24 +293,28 @@ class RayPPOTrainer:
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
answer_key=self.config.data.answer_key,
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
system_prompt=self.config.data.system_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
self.val_dataloader = DataLoader(
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
num_workers=8,
batch_size=len(self.val_dataset)
if self.config.data.val_batch_size == -1
else self.config.data.val_batch_size,
shuffle=False,
drop_last=False,
num_workers=8,
collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
......@@ -423,20 +328,11 @@ class RayPPOTrainer:
self.config.worker.critic.optim.training_steps = training_steps
print(f"Total training steps: {self.training_steps}")
def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores):
"""Log a table of validation samples to wandb"""
generations_to_log = self.config.trainer.val_generations_to_log_to_wandb
if generations_to_log == 0:
def _maybe_log_val_generations(self, inputs: List[str], outputs: List[str], scores: List[float]) -> None:
"""Log a table of validation samples"""
if self.config.trainer.val_generations_to_log <= 0:
return
if generations_to_log > 0 and "wandb" not in self.config.trainer.logger:
print("WARNING: `val_generations_to_log_to_wandb` is set, but no wandb logger is found.")
return
import wandb
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples.sort(key=lambda x: x[0]) # Sort by input text
......@@ -445,43 +341,14 @@ class RayPPOTrainer:
rng = np.random.RandomState(42)
rng.shuffle(samples)
# Take first N samples after shuffling
samples = samples[:generations_to_log]
samples = samples[: self.config.trainer.val_generations_to_log]
self.logger.log_generation(samples, self.global_step)
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)
if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)
# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)
# Add new row with all data
row_data = []
row_data.append(self.global_steps)
for sample in samples:
row_data.extend(sample)
new_table.add_data(*row_data)
# Update reference and log
wandb.log({"val/generations": new_table}, step=self.global_steps)
self.validation_table = new_table
def _validate(self):
def _validate(self) -> Dict[str, Any]:
reward_tensor_lst = []
data_source_lst = []
# Lists to collect samples for the table
sample_inputs = []
sample_outputs = []
sample_scores = []
sample_inputs, sample_outputs, sample_scores = [], [], []
reward_metrics_lst = defaultdict(list)
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
# Store original inputs
......@@ -489,10 +356,10 @@ class RayPPOTrainer:
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
if "pixel_values" in test_batch.non_tensor_batch.keys():
if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["pixel_values", "image_grid_thw", "raw_prompt_ids", "images"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
test_gen_batch = test_batch.pop(
......@@ -500,15 +367,10 @@ class RayPPOTrainer:
non_tensor_batch_keys=["raw_prompt_ids"],
)
test_gen_batch.meta_info = {"do_sample": False}
# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(
test_gen_batch, self.actor_rollout_wg.world_size
)
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
print("validation generation end")
# Store generated outputs
......@@ -519,40 +381,24 @@ class RayPPOTrainer:
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
reward_tensor = self.val_reward_fn(test_batch)
reward_tensor, reward_metrics = self.val_reward_fn(test_batch)
# Store scores
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(
test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])
)
for key, value in reward_metrics.items():
reward_metrics_lst[key].extend(value)
self._maybe_log_val_generations_to_wandb(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
return {"val/reward_score": reward_score, **val_reward_metrics}
reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,)
data_sources = np.concatenate(data_source_lst, axis=0)
# evaluate test_score based on data source
data_source_reward = {}
for i in range(reward_tensor.shape[0]):
data_source = data_sources[i]
if data_source not in data_source_reward:
data_source_reward[data_source] = []
data_source_reward[data_source].append(reward_tensor[i].item())
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
return metric_dict
def init_workers(self):
def init_workers(self) -> None:
"""Init resource pool and worker group"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
......@@ -594,7 +440,7 @@ class RayPPOTrainer:
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
all_wg: Dict[str, FSDPWorker] = {}
self.wg_dicts = []
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
......@@ -605,62 +451,80 @@ class RayPPOTrainer:
self.wg_dicts.append(wg_dict)
if self.use_critic:
self.critic_wg: FSDPWorker = all_wg["critic"]
self.critic_wg = all_wg["critic"]
self.critic_wg.init_model()
if self.use_reference_policy:
self.ref_policy_wg: FSDPWorker = all_wg["ref"]
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg.init_model()
if self.use_reward_model:
self.rm_wg: FSDPWorker = all_wg["rm"]
self.rm_wg = all_wg["rm"]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg: FSDPWorker = all_wg["actor_rollout"]
self.actor_rollout_wg = all_wg["actor_rollout"]
self.actor_rollout_wg.init_model()
def _save_checkpoint(self):
# path: {save_checkpoint_path}/global_step_{global_steps}/actor
local_global_step_folder = os.path.join(
self.config.trainer.save_checkpoint_path, f"global_step_{self.global_steps}"
)
actor_local_path = os.path.join(local_global_step_folder, "actor")
self.actor_rollout_wg.save_checkpoint(
actor_local_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt,
def _save_checkpoint(self) -> None:
# path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
remove_obsolete_ckpt(
self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
)
folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
actor_path = os.path.join(folder_path, "actor")
self.actor_rollout_wg.save_checkpoint(actor_path)
if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, "critic")
self.critic_wg.save_checkpoint(
critic_local_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt,
)
critic_path = os.path.join(folder_path, "critic")
self.critic_wg.save_checkpoint(critic_path)
local_latest_checkpointed_iteration = os.path.join(
self.config.trainer.save_checkpoint_path, "latest_checkpointed_iteration.txt"
)
with open(local_latest_checkpointed_iteration, "w") as f:
f.write(str(self.global_steps))
dataloader_path = os.path.join(folder_path, "dataloader.pt")
dataloader_state_dict = self.train_dataloader.state_dict()
torch.save(dataloader_state_dict, dataloader_path)
def _load_checkpoint(self):
last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
with open(last_global_step_path, "w") as f:
f.write(str(self.global_step))
def _load_checkpoint(self) -> None:
if self.config.trainer.load_checkpoint_path is None:
return
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}")
if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
raise ValueError("`load_checkpoint_path` should end with `global_step_*`.")
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.")
self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
self.actor_rollout_wg.load_checkpoint(actor_path)
if self.use_critic:
self.critic_wg.load_checkpoint(
critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
self.critic_wg.load_checkpoint(critic_path)
dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt")
if os.path.exists(dataloader_path):
dataloader_state_dict = torch.load(dataloader_path, weights_only=False)
self.train_dataloader.load_state_dict(dataloader_state_dict)
else:
print(f"No dataloader state found at {dataloader_path}, will start from scratch.")
def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch["attention_mask"]
batch_size = attention_mask.shape[0]
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
world_size = self.actor_rollout_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(
global_seqlen_lst, k_partitions=world_size, equal_size=True
)
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
batch.reorder(global_idx)
global_balance_stats = log_seqlen_unbalance(
seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
)
metrics.update(global_balance_stats)
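# --- Illustrative sketch (editor's addition, NOT the actual
# get_seqlen_balanced_partitions): a greedy way to split sequence lengths into k
# similarly loaded groups, to show why reordering by total tokens per dp rank
# evens out the work.
def _toy_greedy_partition(seqlens, k):
    buckets, loads = [[] for _ in range(k)], [0] * k
    for i in sorted(range(len(seqlens)), key=lambda j: -seqlens[j]):
        lightest = loads.index(min(loads))  # longest remaining sequence goes to the lightest rank
        buckets[lightest].append(i)
        loads[lightest] += seqlens[i]
    return buckets
# _toy_greedy_partition([512, 64, 300, 256], k=2) -> [[0, 1], [2, 3]] with per-rank
# token loads 576 vs 556 (the real helper also enforces equal-sized partitions).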
def fit(self):
"""
......@@ -668,13 +532,9 @@ class RayPPOTrainer:
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=self.config.to_dict(),
)
self.global_steps = 0
self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
self.global_step = 0
val_metrics: Optional[Dict[str, Any]] = None
# load checkpoint before doing anything
self._load_checkpoint()
......@@ -683,27 +543,24 @@ class RayPPOTrainer:
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.val_before_train:
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
self.logger.log(data=val_metrics, step=self.global_step)
if self.config.trainer.val_only:
return
for _ in range(self.config.trainer.total_episodes):
for batch_dict in self.train_dataloader:
self.global_steps += 1
if self.global_steps >= self.training_steps:
for _ in tqdm(range(self.config.trainer.total_episodes), desc="Episode", position=0):
for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
self.global_step += 1
if self.global_step > self.training_steps:
break
metrics = {}
timing_raw = {}
metrics, timing_raw = {}, {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
if "pixel_values" in batch.non_tensor_batch.keys():
if "multi_modal_inputs" in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["pixel_values", "image_grid_thw", "raw_prompt_ids", "images"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
gen_batch = batch.pop(
......@@ -719,17 +576,15 @@ class RayPPOTrainer:
if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_batch.meta_info["temperature"] = 0.0
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch["uid"] = np.array(
......@@ -739,24 +594,37 @@ class RayPPOTrainer:
batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# compute reward
with _timer("reward", timing_raw):
if self.use_reward_model:
raise NotImplementedError("Reward model is not supported yet.")
# we combine with rule-based rm
reward_tensor, reward_metrics = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
}
metrics.update(reward_metrics)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
# self._balance_batch(batch, metrics=metrics) # TODO: re-enable balance batch
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
with _timer("old", timing_raw):
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs)
# compute ref_log_probs
if self.use_reference_policy:
# compute reference log_prob
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs)
# compute values
if self.use_critic:
......@@ -765,18 +633,8 @@ class RayPPOTrainer:
batch = batch.union(values)
with _timer("adv", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_reward_model:
raise NotImplementedError
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
# compute rewards. apply_kl_penalty if available
if not self.config.worker.actor.use_kl_loss: # not grpo
# apply kl penalty if available
if not self.config.algorithm.use_kl_loss and self.use_reference_policy: # apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
......@@ -790,7 +648,6 @@ class RayPPOTrainer:
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.worker.rollout.n,
)
# update critic
......@@ -798,43 +655,51 @@ class RayPPOTrainer:
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
metrics.update(critic_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
# update actor
if self.config.trainer.critic_warmup <= self.global_step:
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
metrics.update(actor_metrics)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and self.global_steps % self.config.trainer.test_freq == 0
and self.config.trainer.val_freq > 0
and self.global_step % self.config.trainer.val_freq == 0
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
with _timer("validation", timing_raw):
val_metrics = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:
if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
self.logger.log(data=metrics, step=self.global_step)
# perform validation after training
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f"Final validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
self._save_checkpoint()
if (
val_metrics is None
or self.config.trainer.val_freq <= 0
or self.global_step % self.config.trainer.val_freq != 0
):
val_metrics = self._validate()
self.logger.log(data=val_metrics, step=self.global_step)
print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")
if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
self._save_checkpoint()
......@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import get_processor, get_tokenizer
__all__ = ["get_processor", "get_tokenizer"]
......@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
__all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
......@@ -11,20 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.distributed
import torch.distributed as dist
from filelock import FileLock
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedTokenizer, ProcessorMixin
class BaseCheckpointManager:
CHECKPOINT_TRACKER = "latest_global_step.txt"
class BaseCheckpointManager(ABC):
"""
A checkpoint manager that saves and loads
- model
......@@ -44,42 +51,27 @@ class BaseCheckpointManager:
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin
processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
):
self.previous_global_step = None
self.previous_save_local_path = None
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.processor = processor
self.processing_class = processing_class
assert isinstance(self.model, FSDP)
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@abstractmethod
def load_checkpoint(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod
def save_checkpoint(self, *args, **kwargs):
raise NotImplementedError
def remove_previous_save_local_path(self):
if not self.previous_save_local_path:
return
abs_path = os.path.abspath(self.previous_save_local_path)
print(f"Checkpoint manager remove previous save local path: {abs_path}")
if not os.path.exists(abs_path):
return
# remove previous local_path
shutil.rmtree(abs_path, ignore_errors=True)
@staticmethod
def local_mkdir(path):
def local_mkdir(path: str) -> str:
if not os.path.isabs(path):
working_dir = os.getcwd()
path = os.path.join(working_dir, path)
......@@ -89,18 +81,16 @@ class BaseCheckpointManager:
lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
try:
with FileLock(lock_path, timeout=60): # Add timeout
# make a new dir
with FileLock(lock_path, timeout=60):
os.makedirs(path, exist_ok=True)
except Exception as e:
print(f"Warning: Failed to acquire lock for {path}: {e}")
# Even if the lock is not acquired, try to create the directory
os.makedirs(path, exist_ok=True)
os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory
return path
@staticmethod
def get_rng_state():
def get_rng_state() -> Dict[str, Any]:
rng_state = {
"cpu": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state(),
......@@ -110,14 +100,14 @@ class BaseCheckpointManager:
return rng_state
@staticmethod
def load_rng_state(rng_state):
def load_rng_state(rng_state: Dict[str, Any]):
torch.set_rng_state(rng_state["cpu"])
torch.cuda.set_rng_state(rng_state["cuda"])
np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"])
def find_latest_ckpt_path(path, directory_format="global_step_{}"):
def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
if path is None:
return None
......@@ -128,6 +118,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
with open(tracker_file, "rb") as f:
iteration = int(f.read().decode())
ckpt_path = os.path.join(path, directory_format.format(iteration))
if not os.path.exists(ckpt_path):
print("Checkpoint does not exist: %s", ckpt_path)
......@@ -137,8 +128,33 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
return ckpt_path
def get_checkpoint_tracker_filename(root_path: str):
def get_checkpoint_tracker_filename(root_path: str) -> str:
"""
Tracker file records the latest checkpoint during training to restart from.
"""
return os.path.join(root_path, "latest_checkpointed_iteration.txt")
return os.path.join(root_path, CHECKPOINT_TRACKER)
def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
"""
Remove the obsolete checkpoints that exceed the save_limit.
"""
if save_limit <= 0:
return
if not os.path.exists(path):
return
pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
ckpt_folders = []
for folder in os.listdir(path):
if match := re.match(pattern, folder):
step = int(match.group(1))
if step < global_step:
ckpt_folders.append((step, folder))
ckpt_folders.sort(reverse=True)
for _, folder in ckpt_folders[save_limit - 1 :]:
folder_path = os.path.join(path, folder)
shutil.rmtree(folder_path, ignore_errors=True)
print(f"Removed obsolete checkpoint: {folder_path}")