Commit c132cbcb authored by chenych

0402 update

parent f92481f0
@@ -12,11 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"__version__\W*=\W*\"([^\"]+)\""
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
@@ -31,18 +41,19 @@ extra_require = {
def main():
setup(
name="verl",
version=get_version(),
description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="verl",
author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
license="Apache 2.0 License",
url="https://github.com/volcengine/verl",
package_dir={"": "."},
packages=find_packages(where="."),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
)
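As a quick aside (not part of the diff), here is roughly how the new `get_version` helper resolves the package version; the sample string below stands in for the real `verl/__init__.py` content:

```python
# Minimal sketch of the version extraction done by get_version() above.
import re

sample = '__version__ = "0.2.0.dev"'  # hypothetical __init__.py content
pattern = r"__version__\W*=\W*\"([^\"]+)\""
(version,) = re.findall(pattern, sample)
print(version)  # 0.2.0.dev
```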
@@ -12,8 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .protocol import DataProto
__all__ = ["DataProto"]
__version__ = "0.2.0.dev"
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .transformers.flash_attention_utils import flash_attention_forward
from .transformers.qwen2_vl import qwen2_vl_attn_forward
def apply_ulysses_patch(model_type: str) -> None:
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
elif model_type in ("qwen2_vl", "qwen2_5_vl"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
else:
raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
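A hypothetical usage sketch (mine, not from the commit): the patch is presumably applied once, keyed by the Hugging Face `model_type`, before the model runs with `flash_attention_2`. The model name is an example, and `apply_ulysses_patch` is assumed to be imported from this new module.

```python
# Assumed usage of apply_ulysses_patch; model name and import path are illustrative.
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "Qwen/Qwen2-7B-Instruct"  # example checkpoint
config = AutoConfig.from_pretrained(model_name)
apply_ulysses_patch(config.model_type)  # "qwen2" -> replaces the flash_attention_2 forward
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="flash_attention_2")
```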
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
from ...utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_world_size,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
_flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def prepare_fa2_from_position_ids(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seqlens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
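To make the cu_seqlens construction above concrete, a small self-contained example (mine): position ids restart at 0 for every packed sequence, so the zero positions mark sequence boundaries.

```python
# Two sequences of lengths 3 and 2 packed into one row of position ids.
import torch

position_ids = torch.tensor([0, 1, 2, 0, 1])
indices = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seqlens = torch.cat(
    (indices[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
)
print(cu_seqlens)               # tensor([0, 3, 5], dtype=torch.int32)
print(cu_seqlens.diff().max())  # tensor(3) -> max sequence length in the batch
```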
def _custom_flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
query_length: int,
is_causal: bool = True,
position_ids: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
deterministic: Optional[bool] = None,
**kwargs,
):
"""
Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if _flash_supports_deterministic:
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
if kwargs.get("softcap") is not None:
flash_kwargs["softcap"] = kwargs.pop("softcap")
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype=torch.bfloat16
)
sp_size = get_ulysses_sequence_parallel_world_size()
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids[0]
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=kwargs.pop("dropout", 0.0),
softmax_scale=kwargs.pop("softmax_scale", None),
causal=causal,
**flash_kwargs,
)
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_length,
is_causal=is_causal,
sliding_window=sliding_window,
use_top_left_mask=use_top_left_mask,
deterministic=deterministic,
**kwargs,
) # do not pass position_ids to old flash_attention_forward
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
return attn_output
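For intuition (my sketch, not part of the patch): the `torch.diff(position_ids, dim=-1) >= 0` test above is what routes packed inputs to the varlen kernel, because packed sequences restart their position ids.

```python
import torch

packed = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])  # two sequences packed into one row
plain = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])   # one ordinary sequence
print((torch.diff(packed, dim=-1) >= 0).all())  # tensor(False) -> varlen flash attention
print((torch.diff(plain, dim=-1) >= 0).all())   # tensor(True)  -> regular path
```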
def flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
# This is before the transpose
q_len = query.shape[2]
# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
kwargs.pop("is_causal", None)
attn_output = _custom_flash_attention_forward(
query,
key,
value,
attention_mask,
query_length=q_len,
is_causal=True,
dropout=dropout,
softmax_scale=scaling,
sliding_window=sliding_window,
softcap=softcap,
use_top_left_mask=_flash_use_top_left_mask,
**kwargs,
)
return attn_output, None
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from .flash_attention_utils import flash_attention_forward
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
apply_multimodal_rotary_pos_emb,
repeat_kv,
)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
except ImportError:
pass
def get_rope_index(
processor: "Qwen2VLProcessor",
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
video_grid_thw: Optional[torch.Tensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
"""
spatial_merge_size = processor.image_processor.merge_size
tokens_per_second = 2
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
image_index, video_index = 0, 0
input_ids = input_ids[attention_mask == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
else:
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
return position_ids
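To illustrate the layout `get_rope_index` produces, here is a simplified sketch (mine) of the core arithmetic rather than a call of the function itself, since that needs a real processor: two text tokens followed by a single image whose merged grid is t=1, h=2, w=2. For an image, `second_per_grid_t` is 0, so the temporal index stays flat.

```python
import torch

text_len, st_idx = 2, 0
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
# text tokens share one index across the (t, h, w) axes: [[0, 1], [0, 1], [0, 1]]

llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 2
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
vision_pos = torch.stack([t_index, h_index, w_index]) + text_len + st_idx
print(torch.cat([text_pos, vision_pos], dim=1))
# tensor([[0, 1, 2, 2, 2, 2],
#         [0, 1, 2, 2, 3, 3],
#         [0, 1, 2, 3, 2, 3]])
```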
def qwen2_vl_attn_forward(
self: "Qwen2VLAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, None, None]:
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
attn_output, _ = flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=dropout_rate,
sliding_window=sliding_window,
position_ids=position_ids, # important: pass position ids
) # (batch_size, seq_length, num_head / sp_size, head_size)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, None, None
@@ -26,13 +26,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import ray
import torch
import torch.distributed as dist
from numpy.typing import NDArray
from tensordict import TensorDict
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader
from .utils.py_functional import union_two_dict
try:
@@ -89,20 +88,21 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
)
for key in tensor_dict2.keys():
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = tensor_dict2[key]
return tensor_dict1
def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
for key in tensor_dict2.keys():
if key in tensor_dict1:
assert isinstance(tensor_dict2[key], np.ndarray)
assert isinstance(tensor_dict1[key], np.ndarray)
if not np.all(tensor_dict1[key] == tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = tensor_dict2[key]
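A small usage sketch (mine) of the union semantics enforced here: overlapping keys are tolerated only when the values match exactly.

```python
import numpy as np

a = {"input_ids": np.array([1, 2, 3])}
b = {"input_ids": np.array([1, 2, 3]), "labels": np.array([0, 1, 1])}
merged = union_numpy_dict(a, b)   # fine: the overlapping values are equal
print(sorted(merged))             # ['input_ids', 'labels']

c = {"input_ids": np.array([9, 9, 9])}
# union_numpy_dict(a, c)          # would raise ValueError("Key already exists: input_ids.")
```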
@@ -151,6 +151,7 @@ def collate_fn(data_items: list["DataProtoItem"]):
batch = torch.stack(batch).contiguous()
non_tensor_batch = batch_collate(non_tensor_batch)
non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()}
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
@@ -189,7 +190,8 @@ class DataProto:
def __getitem__(self, item):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
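A sketch (mine) of the new indexing behavior, assuming `torch` and `DataProto` are in scope as in this module: slices keep the `DataProto` type, single indices still produce a `DataProtoItem`.

```python
proto = DataProto.from_single_dict({"input_ids": torch.zeros(4, 8, dtype=torch.long)})
print(type(proto[0:2]).__name__)  # DataProto      (slice -> still a batch)
print(type(proto[0]).__name__)    # DataProtoItem  (single sample)
```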
def __getstate__(self):
buffer = io.BytesIO()
@@ -229,7 +231,6 @@ class DataProto:
size_of_numpy_array = 0
for value in self.non_tensor_batch.values():
if isinstance(value, np.ndarray):
size_of_numpy_array += value.nbytes
size_of_numpy_array /= 1024**3
@@ -254,13 +255,13 @@ class DataProto:
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
@classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, NDArray]], meta_info=None):
tensors = {}
non_tensors = {}
for key, value in data.items():
if isinstance(value, torch.Tensor):
tensors[key] = value
elif isinstance(value, np.ndarray):
non_tensors[key] = value
else:
raise ValueError(f"Unsupported type in data {type(value)}")
@@ -472,8 +473,6 @@ class DataProto:
assert len(self) % chunks == 0, (
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
)
chunk_size = len(self) // chunks
if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
else:
@@ -481,12 +480,8 @@
non_tensor_batch_lst = [{} for _ in range(chunks)]
for key, value in self.non_tensor_batch.items():
assert isinstance(value, np.ndarray)
if isinstance(value, np.ndarray):
non_tensor_lst = np.array_split(value, chunks)
else:
non_tensor_lst = [value[i : i + chunk_size] for i in range(0, len(self), chunk_size)]
assert len(non_tensor_lst) == chunks
for i in range(chunks):
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
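For reference, a one-liner (mine) showing why `np.array_split` handles the non-tensor part: it splits an object array into the requested number of equally sized chunks.

```python
import numpy as np

values = np.array(["a", "b", "c", "d", "e", "f"], dtype=object)
print([chunk.tolist() for chunk in np.array_split(values, 3)])
# [['a', 'b'], ['c', 'd'], ['e', 'f']]
```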
@@ -524,12 +519,7 @@ class DataProto:
non_tensor_batch = batch_collate([d.non_tensor_batch for d in data])
for key, value in non_tensor_batch.items():
if isinstance(value[0], np.ndarray):
non_tensor_batch[key] = np.concatenate(value, axis=0)
else:
non_tensor_batch[key] = []
for item in value:
non_tensor_batch[key].extend(item)
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
@@ -574,16 +564,10 @@ class DataProto:
repeated_non_tensor_batch = {}
for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray):
if interleave:
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
else:
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
else:
if interleave:
repeated_non_tensor_batch[key] = [item for item in value for _ in range(repeat_times)]
else:
repeated_non_tensor_batch[key] = [item for _ in range(repeat_times) for item in value]
return DataProto(
batch=repeated_batch,
@@ -591,39 +575,6 @@
meta_info=self.meta_info,
)
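A quick illustration (mine) of the `interleave` flag handled above: `np.repeat` interleaves the repeated samples, while `np.tile` appends whole copies of the batch.

```python
import numpy as np

value = np.array(["a", "b"], dtype=object)
print(np.repeat(value, 2, axis=0).tolist())                     # ['a', 'a', 'b', 'b']
print(np.tile(value, (2,) + (1,) * (value.ndim - 1)).tolist())  # ['a', 'b', 'a', 'b']
```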
def broadcast(self, src: int, group: Optional[ProcessGroup] = None):
for key in self.batch.sorted_keys:
dist.broadcast(self.batch[key], src=src, group=group, async_op=False)
object_list = [self.non_tensor_batch]
dist.broadcast_object_list(object_list, src=src, group=group)
self.non_tensor_batch = object_list[0]
def all_gather(self, group: Optional[ProcessGroup] = None):
world_size = dist.get_world_size(group)
output = {}
for key in self.batch.sorted_keys:
value = self.batch[key].contiguous()
output[key] = [torch.empty_like(value) for _ in range(world_size)]
dist.all_gather(output[key], value, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=0)
self.batch = TensorDict(output, batch_size=self.batch.batch_size[0] * world_size)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(world_size)]
dist.all_gather_object(all_non_tensor_batch, self.non_tensor_batch, group=group)
non_tensor_batch = defaultdict(list)
for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray):
non_tensor_batch[key] = np.concatenate([batch[key] for batch in all_non_tensor_batch])
else:
for batch in all_non_tensor_batch:
non_tensor_batch[key].extend(batch[key])
self.non_tensor_batch = non_tensor_batch
self.check_consistency()
@dataclass
class DataProtoFuture:
@@ -664,10 +615,53 @@ class DataProtoFuture:
return arg_future_lst
def get(self):
outputs = ray.get(self.futures) # dp_size.
for output in outputs:
assert isinstance(output, DataProto)
outputs = self.collect_fn(outputs) # select dp, concat
if self.dispatch_fn is not None:
outputs = self.dispatch_fn(outputs) # split in batch dim, select using dp
return outputs
def allgather_dict_tensors(
tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
) -> Union[Dict[str, torch.Tensor], TensorDict]:
"""
TODO: optimize this.
- We can use async ops
- We can use only one allgather
"""
if isinstance(tensors, TensorDict):
is_tensor_dict = True
tensors_as_dict = tensors.to_dict()
else:
tensors_as_dict = tensors
is_tensor_dict = False
output = {}
sorted_keys = sorted(tensors_as_dict.keys())
for key in sorted_keys:
val = tensors_as_dict[key]
output[key] = [torch.empty_like(val) for _ in range(size)]
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=dim)
if is_tensor_dict:
output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)
return output return output
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
# Note that this is an inplace operator just like torch.distributed.all_gather
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(size)]
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
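A hypothetical usage sketch of the new module-level helper; it needs an initialized `torch.distributed` process group and a local `DataProto`, so it is not runnable standalone.

```python
import torch.distributed as dist

# assumes dist.init_process_group(...) was called earlier and `data` is the
# DataProto held locally by this rank
world_size = dist.get_world_size()
all_gather_data_proto(data, size=world_size, group=None)  # None -> default group
# every rank now holds the concatenation of all ranks' batches, in rank order
```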
@@ -14,3 +14,6 @@
from .worker import Worker
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum, auto
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING, Dict, List, Literal, Union
import ray
from ...protocol import DataProto, DataProtoFuture
if TYPE_CHECKING:
from .worker_group import WorkerGroup
# here we add a magic number of avoid user-defined function already have this attribute
@@ -27,13 +31,13 @@ MAGIC_ATTR = "attrs_3141562937"
class Dispatch(Enum):
RANK_ZERO = auto()
ONE_TO_ALL = auto()
ALL_TO_ALL = auto()
DP_COMPUTE = auto()
DP_COMPUTE_PROTO = auto()
DP_COMPUTE_PROTO_WITH_FUNC = auto()
DP_COMPUTE_METRIC = auto()
class Execute(Enum):
@@ -41,100 +45,85 @@ class Execute(Enum):
RANK_ZERO = 1
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {}
for key, value in kwargs.items():
assert isinstance(value, (DataProto, DataProtoFuture))
splitted_kwargs[key] = value.chunk(chunks=chunks)
return splitted_args, splitted_kwargs
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
args = tuple([arg] * worker_group.world_size for arg in args)
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
return args, kwargs
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
return args, kwargs
def collect_all_to_all(worker_group: "WorkerGroup", output):
return output
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
# make sure all the elements in output has the same type
for output in outputs:
assert type(output) is type(outputs[0])
output = outputs[0]
if isinstance(output, DataProto):
return DataProto.concat(outputs)
elif isinstance(output, ray.ObjectRef):
return DataProtoFuture.concat(outputs)
else:
raise NotImplementedError
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
for arg in args:
assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
for value in kwargs.values():
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
return args, kwargs
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
assert len(outputs) == worker_group.world_size
return outputs
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
assert type(args[0]) is FunctionType # NOTE: The first one args is a function!
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
return splitted_args_with_func, splitted_kwargs
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
for output in outputs:
assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"
outputs = collect_dp_compute(worker_group, outputs)
return _concat_data_proto_or_future(outputs)
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: {
"dispatch_fn": dispatch_one_to_all,
@@ -144,6 +133,10 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_all_to_all,
"collect_fn": collect_all_to_all,
},
Dispatch.DP_COMPUTE: {
"dispatch_fn": dispatch_dp_compute,
"collect_fn": collect_dp_compute,
},
Dispatch.DP_COMPUTE_PROTO: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute_data_proto,
@@ -152,11 +145,15 @@
"dispatch_fn": dispatch_dp_compute_data_proto_with_func,
"collect_fn": collect_dp_compute_data_proto,
},
Dispatch.DP_COMPUTE_METRIC: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute,
},
}
return predefined_dispatch_mode_fn[dispatch_mode]
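A toy check (mine) of what the dispatch table encodes; a fake worker group with `world_size = 2` is enough to see `ONE_TO_ALL` replicate its arguments once per worker (the class below is a hypothetical stand-in, which works because the new dispatch functions no longer assert the `WorkerGroup` type):

```python
class _FakeWorkerGroup:  # stand-in for WorkerGroup, only world_size is needed
    world_size = 2

fns = get_predefined_dispatch_fn(Dispatch.ONE_TO_ALL)
args, kwargs = fns["dispatch_fn"](_FakeWorkerGroup(), "ping", step=1)
print(args)    # (['ping', 'ping'],)
print(kwargs)  # {'step': [1, 1]}
```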
def get_predefined_execute_fn(execute_mode: Execute):
"""
Note that here we only asks execute_all and execute_rank_zero to be implemented
Leave the choice of how these two functions handle argument 'blocking' to users
@@ -168,17 +165,17 @@
return predefined_execute_mode_fn[execute_mode]
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
assert isinstance(dispatch_mode, (Dispatch, dict)), (
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
)
if isinstance(dispatch_mode, dict):
necessary_keys = ["dispatch_fn", "collect_fn"]
for key in necessary_keys:
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
def _check_execute_mode(execute_mode: Execute):
assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
@@ -189,9 +186,10 @@ def _materialize_futures(*args, **kwargs):
arg = arg.get()
# add more type to materialize
new_args.append(arg)
for key, value in kwargs.items():
if isinstance(value, DataProtoFuture):
kwargs[key] = value.get()
new_args = tuple(new_args)
return new_args, kwargs
@@ -18,11 +18,13 @@ the class for Worker
import os
import socket
from dataclasses import dataclass
from typing import Tuple
import ray
import torch
from .decorator import Dispatch, Execute, register
from .register_center.ray import create_worker_group_register_center
@dataclass
@@ -40,7 +42,7 @@ class DistGlobalInfo:
class WorkerHelper:
def _get_node_ip(self) -> str:
host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6
@@ -49,12 +51,12 @@ class WorkerHelper:
host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip
def _get_free_port(self) -> int:
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
def get_availale_master_addr_port(self) -> Tuple[str, str]:
return self._get_node_ip(), str(self._get_free_port())
def _get_pid(self):
@@ -81,16 +83,26 @@ class WorkerMeta:
# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
"""A (distributed) worker."""
_world_size: int
_rank: int
_local_world_size: int
_local_rank: int
_master_addr: str
_master_port: str
_cuda_visible_devices: str
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
# note that here we use int to distinguish
disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
if disable_worker_init:
return instance
rank = os.getenv("RANK", None)
worker_group_prefix = os.getenv("WG_PREFIX", None)
# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
@@ -112,13 +124,19 @@ class Worker(WorkerHelper):
def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.getenv("RANK"))
self._rank = rank
self._world_size = world_size
if "AMD" in torch.cuda.get_device_name():
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
torch.cuda.set_device(int(cuda_visible_devices))
master_addr = os.getenv("MASTER_ADDR")
master_port = os.getenv("MASTER_PORT")
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -149,6 +167,7 @@ class Worker(WorkerHelper):
if val is not None:
# print(f"set {key} to {val}")
os.environ[key] = str(val)
os.environ["REDIS_STORE_SERVER_HOST"] = (
str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
)
@@ -157,7 +176,7 @@
return self._master_addr, self._master_port
def get_cuda_visible_devices(self):
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices
def print_rank0(self, *args, **kwargs):
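For orientation, a sketch of mine with example values only: these are the environment variables a `Worker` reads during `__init__`; in practice the worker group sets them before spawning the actor.

```python
import os

os.environ.update(
    {
        "WORLD_SIZE": "2",
        "RANK": "0",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",
        "LOCAL_WORLD_SIZE": "2",
        "LOCAL_RANK": "0",
    }
)
# Worker() would now pick these up and configure rank, world size, and master address.
```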
@@ -19,20 +19,20 @@ import logging
import signal
import threading
import time
from typing import Any, Callable, Dict, List, Optional
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
"""The resource pool with meta info such as world size."""
def __init__(
self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
) -> None:
if process_on_nodes is None:
process_on_nodes = []
self._store = process_on_nodes
self.max_collocate_count = max_collocate_count
self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node
@@ -73,28 +73,23 @@ class ClassWithInitArgs:
self.args = args
self.kwargs = kwargs
# def add_arg(self, arg):
# self.args += (arg,)
# def add_kwarg(self, key, value):
# self.kwargs[key] = value
def __call__(self) -> Any:
return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
while True:
for worker in workers:
if not is_alive(worker):
logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
signal.raise_signal(signal.SIGABRT)
time.sleep(gap_time)
class WorkerGroup:
"""A group of workers"""
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False
@@ -136,14 +131,10 @@
def world_size(self):
return len(self._workers)
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup,
# MegatronWorkerGroup, XperfWorkerGroup should skip
def _bind_worker_method(self, user_defined_cls, func_generator):
"""
Bind the worker method to the WorkerGroup
"""
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
@@ -13,27 +13,28 @@
# limitations under the License.
import os
import random
import re
import string
import time
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
import ray
from ray.actor import ActorHandle
from ray.experimental.state.api import get_actor
from ray.util import list_named_actors
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
from ..base.decorator import MAGIC_ATTR
__all__ = ["Worker"]
def get_random_string(length: int) -> str:
letters_digits = string.ascii_letters + string.digits
return "".join(random.choice(letters_digits) for _ in range(length))
@@ -50,6 +51,27 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block
return func
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
"""
Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.
FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK
to be consistent across nodes when resume from checkpoint.
With this function, if there's only one resource pool and there's no node change, RANK should be consistent
across nodes in multiple ray jobs, even if the whole ray cluster is restarted.
"""
node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()}
pg_ip = {}
for pg in pgs:
specs = ray._private.state.state.placement_group_table(pg.id)
# all bunles should be on the same node
node_id = specs["bundles_to_node_id"][0]
pg_ip[pg.id] = node_ip[node_id]
return sorted(pgs, key=lambda pg: pg_ip[pg.id])
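The sorting contract above can be summarized with a tiny sketch (mine, using made-up placement-group ids and IPs): placement groups are ordered by the IP of the node hosting them, which keeps rank assignment stable across ray jobs as long as the nodes do not change.

```python
pg_ip = {"pg-a": "10.0.0.2", "pg-b": "10.0.0.1"}  # hypothetical pg-id -> node-ip map
pgs = ["pg-a", "pg-b"]
print(sorted(pgs, key=lambda pg: pg_ip[pg]))      # ['pg-b', 'pg-a']
```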
class RayResourcePool(ResourcePool):
def __init__(
self,
@@ -57,7 +79,7 @@
use_gpu: bool = True,
name_prefix: str = "",
max_colocate_count: int = 5,
detached: bool = False,
) -> None:
super().__init__(process_on_nodes, max_colocate_count)
self.use_gpu = use_gpu
@@ -66,7 +88,7 @@
self.pgs = None
self.detached = detached
def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]:
if self.pgs is not None:
return self.pgs
@@ -97,7 +119,7 @@
def extract_pg_from_exist(
resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List[PlacementGroup]:
src_pgs = [
pg
for role_name, resource_pool in resource_pools.items()
@@ -151,7 +173,12 @@ class RayClassWithInitArgs(ClassWithInitArgs):
self._options.update(options)
def __call__(
self,
placement_group: PlacementGroup,
placement_group_bundle_idx: int,
use_gpu: bool = True,
num_gpus: int = 1,
sharing_with: Worker = None,
) -> Any:
if sharing_with is not None:
target_node_id = ray.get(sharing_with.get_node_id.remote())
@@ -188,8 +215,8 @@ class RayWorkerGroup(WorkerGroup):
ray_cls_with_init: RayClassWithInitArgs = None,
bin_pack: bool = True,
name_prefix: str = None,
detached: bool = False,
worker_names: List[str] = None,
**kwargs,
) -> None:
super().__init__(resource_pool=resource_pool, **kwargs)
@@ -210,21 +237,24 @@
if ray_cls_with_init is not None:
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
def _is_worker_alive(self, worker: ActorHandle) -> bool:
worker_state_dict = get_actor(worker._actor_id.hex())
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
def _init_with_detached_workers(self, worker_names: List[str]) -> None:
workers = [ray.get_actor(name=name) for name in worker_names]
self._workers = workers
self._world_size = len(worker_names)
def _init_with_resource_pool(
self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool
):
use_gpu = resource_pool.use_gpu
strategy = "PACK"
if bin_pack:
strategy = "STRICT_PACK"
pgs = resource_pool.get_placement_groups(strategy=strategy)
world_size = resource_pool.world_size
self._world_size = world_size
@@ -232,8 +262,8 @@
num_gpus = 1 / resource_pool.max_collocate_count
rank = -1
local_world_size = resource_pool.store[0]
for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the "
for local_rank in range(local_world_size):
rank += 1
@@ -251,8 +281,6 @@
env_vars["MASTER_ADDR"] = self._master_addr
env_vars["MASTER_PORT"] = self._master_port
cia_name = type(ray_cls_with_init.cls).__name__
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
@@ -272,7 +300,7 @@
if rank == 0:
register_center_actor = None
for _ in range(360):
if f"{self.name_prefix}_register_center" not in list_named_actors():
time.sleep(1)
else:
......
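The register-center wait above polls the named-actor registry once per second, now for up to 360 attempts. Below is a standalone sketch of the same polling pattern; the helper `wait_for_named_actor` is hypothetical and not part of the repository.

```python
import time

import ray
from ray.util import list_named_actors


def wait_for_named_actor(name: str, timeout_s: int = 360, poll_s: float = 1.0):
    """Poll until a named Ray actor is registered, then return its handle.

    Mirrors the register-center wait loop above; this helper itself is a
    hypothetical sketch, not repository code.
    """
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if name in list_named_actors():
            return ray.get_actor(name)
        time.sleep(poll_s)
    raise TimeoutError(f"actor {name!r} did not register within {timeout_s}s")
```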
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from verl.workers.config import WorkerConfig from ..workers.config import WorkerConfig
def recursive_post_init(dataclass_obj): def recursive_post_init(dataclass_obj):
...@@ -36,12 +36,13 @@ class DataConfig: ...@@ -36,12 +36,13 @@ class DataConfig:
train_files: str = "" train_files: str = ""
val_files: str = "" val_files: str = ""
prompt_key: str = "prompt" prompt_key: str = "prompt"
answer_key: str = "answer"
image_key: str = "images"
max_prompt_length: int = 512 max_prompt_length: int = 512
max_response_length: int = 512 max_response_length: int = 512
rollout_batch_size: int = 512 rollout_batch_size: int = 512
return_raw_input_ids: bool = False val_batch_size: int = -1
return_raw_prompt: bool = False system_prompt: Optional[str] = None
system_prompt: str = r"Please reason step by step, and put your final answer within \boxed{}."
shuffle: bool = True shuffle: bool = True
seed: int = 1 seed: int = 1
max_pixels: int = 4194304 max_pixels: int = 4194304
...@@ -52,10 +53,12 @@ class DataConfig: ...@@ -52,10 +53,12 @@ class DataConfig:
class AlgorithmConfig: class AlgorithmConfig:
gamma: float = 1.0 gamma: float = 1.0
lam: float = 1.0 lam: float = 1.0
adv_estimator: str = "gae" adv_estimator: str = "grpo"
disable_kl: bool = False
use_kl_loss: bool = False
kl_penalty: str = "kl" kl_penalty: str = "kl"
kl_type: str = "fixed"
kl_coef: float = 1e-3 kl_coef: float = 1e-3
kl_type: str = "fixed"
kl_horizon: float = 0.0 kl_horizon: float = 0.0
kl_target: float = 0.0 kl_target: float = 0.0
...@@ -67,18 +70,17 @@ class TrainerConfig: ...@@ -67,18 +70,17 @@ class TrainerConfig:
project_name: str = "easy_r1" project_name: str = "easy_r1"
experiment_name: str = "demo" experiment_name: str = "demo"
logger: Tuple[str] = ("console", "wandb") logger: Tuple[str] = ("console", "wandb")
val_generations_to_log_to_wandb: int = 0
nnodes: int = 1 nnodes: int = 1
n_gpus_per_node: int = 8 n_gpus_per_node: int = 8
save_freq: int = -1 critic_warmup: int = 0
load_checkpoint_path: Optional[str] = None val_freq: int = -1
val_before_train: bool = True val_before_train: bool = True
val_only: bool = False val_only: bool = False
test_freq: int = -1 val_generations_to_log: int = 0
critic_warmup: int = 0 save_freq: int = -1
remove_previous_ckpt: bool = False save_limit: int = -1
del_local_ckpt_after_load: bool = False
save_checkpoint_path: Optional[str] = None save_checkpoint_path: Optional[str] = None
load_checkpoint_path: Optional[str] = None
def post_init(self): def post_init(self):
if self.save_checkpoint_path is None: if self.save_checkpoint_path is None:
...@@ -95,6 +97,10 @@ class PPOConfig: ...@@ -95,6 +97,10 @@ class PPOConfig:
def post_init(self): def post_init(self):
self.worker.rollout.prompt_length = self.data.max_prompt_length self.worker.rollout.prompt_length = self.data.max_prompt_length
self.worker.rollout.response_length = self.data.max_response_length self.worker.rollout.response_length = self.data.max_response_length
self.worker.actor.disable_kl = self.algorithm.disable_kl
self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
self.worker.actor.kl_penalty = self.algorithm.kl_penalty
self.worker.actor.kl_coef = self.algorithm.kl_coef
def deep_post_init(self): def deep_post_init(self):
recursive_post_init(self) recursive_post_init(self)
......
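`PPOConfig.post_init` above copies the KL-related fields from `AlgorithmConfig` onto the actor worker config so both sides stay consistent. A toy illustration of that propagation, using hypothetical stand-in dataclasses rather than the repository's real classes:

```python
from dataclasses import dataclass, field


@dataclass
class ToyActorConfig:  # stand-in, not the repository's ActorConfig
    use_kl_loss: bool = False
    kl_penalty: str = "kl"
    kl_coef: float = 0.0


@dataclass
class ToyAlgorithmConfig:  # stand-in, not the repository's AlgorithmConfig
    use_kl_loss: bool = True
    kl_penalty: str = "low_var_kl"
    kl_coef: float = 1e-3


@dataclass
class ToyPPOConfig:
    algorithm: ToyAlgorithmConfig = field(default_factory=ToyAlgorithmConfig)
    actor: ToyActorConfig = field(default_factory=ToyActorConfig)

    def post_init(self) -> None:
        # same propagation as PPOConfig.post_init above
        self.actor.use_kl_loss = self.algorithm.use_kl_loss
        self.actor.kl_penalty = self.algorithm.kl_penalty
        self.actor.kl_coef = self.algorithm.kl_coef


cfg = ToyPPOConfig()
cfg.post_init()
assert cfg.actor.kl_penalty == "low_var_kl" and cfg.actor.kl_coef == 1e-3
```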
# Copyright 2022 The HuggingFace Team
# Copyright 2024 Bytedance Ltd. and/or its affiliates # Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,50 +18,55 @@ The function implemented in this file should be used by trainer with different d ...@@ -18,50 +18,55 @@ The function implemented in this file should be used by trainer with different d
implement PPO implement PPO
""" """
from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import verl.utils.torch_functional as verl_F from ..utils import torch_functional as VF
if TYPE_CHECKING: if TYPE_CHECKING:
from verl.trainer.config import AlgorithmConfig from .config import AlgorithmConfig
class AdaptiveKLController: class KLController(ABC):
""" @abstractmethod
Adaptive KL controller described in the paper: def update(self, current_kl: float, n_steps: int) -> None: ...
https://arxiv.org/pdf/1909.08593.pdf
"""
class AdaptiveKLController(KLController):
"""Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf"""
def __init__(self, init_kl_coef: float, target_kl: float, horizon: float): def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
self.value = init_kl_coef self.value = init_kl_coef
self.target = target_kl self.target = target_kl
self.horizon = horizon self.horizon = horizon
def update(self, current_kl, n_steps): def update(self, current_kl: float, n_steps: int) -> None:
target = self.target target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult self.value *= mult
class FixedKLController: class FixedKLController(KLController):
"""Fixed KL controller.""" """Fixed KL controller."""
def __init__(self, kl_coef: float): def __init__(self, init_kl_coef: float):
self.value = kl_coef self.value = init_kl_coef
def update(self, current_kl, n_steps): def update(self, current_kl: float, n_steps: int) -> None:
pass pass
def get_kl_controller(algorithm_config: "AlgorithmConfig"): def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
"""Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
if algorithm_config.kl_type == "fixed": if algorithm_config.kl_type == "fixed":
kl_ctrl = FixedKLController(kl_coef=algorithm_config.kl_coef) kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
elif algorithm_config.kl_type == "adaptive": elif algorithm_config.kl_type == "adaptive":
assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}." assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
kl_ctrl = AdaptiveKLController( kl_ctrl = AdaptiveKLController(
...@@ -70,19 +75,20 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig"): ...@@ -70,19 +75,20 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig"):
horizon=algorithm_config.kl_horizon, horizon=algorithm_config.kl_horizon,
) )
else: else:
raise ValueError("Unknown kl_ctrl type") raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")
return kl_ctrl return kl_ctrl
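`get_kl_controller` returns either a fixed or an adaptive controller. The adaptive rule nudges the KL coefficient toward the target over `horizon` steps; a minimal self-contained sketch of that rule (the class name here is illustrative, not the repository's):

```python
import numpy as np


class AdaptiveKLSketch:
    """Illustrative re-implementation of the adaptive rule from
    https://arxiv.org/pdf/1909.08593.pdf (names are not the repository's)."""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # error is clipped to +/-20% so the coefficient changes smoothly
        proportional_error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon


ctrl = AdaptiveKLSketch(init_kl_coef=0.2, target_kl=6.0, horizon=10000)
ctrl.update(current_kl=9.0, n_steps=256)  # measured KL above target -> coefficient grows
print(round(ctrl.value, 4))               # 0.201
```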
@torch.no_grad()
def compute_gae_advantage_return( def compute_gae_advantage_return(
token_level_rewards: torch.Tensor, token_level_rewards: torch.Tensor,
values: torch.Tensor, values: torch.Tensor,
eos_mask: torch.Tensor, eos_mask: torch.Tensor,
gamma: torch.Tensor, gamma: torch.Tensor,
lam: torch.Tensor, lam: torch.Tensor,
): ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py """Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513
Args: Args:
token_level_rewards: `(torch.Tensor)` token_level_rewards: `(torch.Tensor)`
...@@ -103,27 +109,26 @@ def compute_gae_advantage_return( ...@@ -103,27 +109,26 @@ def compute_gae_advantage_return(
shape: (bs, response_length) shape: (bs, response_length)
""" """
with torch.no_grad():
lastgaelam = 0 lastgaelam = 0
advantages_reversed = [] advantages_reversed = []
gen_len = token_level_rewards.shape[-1] gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)): for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam) advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values advantages = torch.stack(advantages_reversed[::-1], dim=1)
advantages = verl_F.masked_whiten(advantages, eos_mask) returns = (advantages + values) * eos_mask
advantages = VF.masked_whiten(advantages, eos_mask) * eos_mask
return advantages, returns return advantages, returns
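`compute_gae_advantage_return` walks the sequence backwards accumulating TD errors before whitening. A plain, unmasked sketch of the same recursion on toy tensors (the helper name is hypothetical; masking and whitening are omitted):

```python
import torch


def gae_sketch(rewards: torch.Tensor, values: torch.Tensor, gamma: float, lam: float):
    """Plain GAE recursion on dense tensors; illustrative only."""
    gen_len = rewards.shape[-1]
    lastgaelam, advantages_reversed = 0.0, []
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    return advantages, advantages + values


rewards = torch.tensor([[0.0, 0.0, 1.0]])  # reward only on the last token
values = torch.tensor([[0.1, 0.2, 0.3]])
adv, ret = gae_sketch(rewards, values, gamma=0.99, lam=0.95)
```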
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage( def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6 token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Compute advantage for GRPO, operating only on Outcome reward Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response). (with only one scalar reward for each response).
...@@ -141,15 +146,13 @@ def compute_grpo_outcome_advantage( ...@@ -141,15 +146,13 @@ def compute_grpo_outcome_advantage(
""" """
response_length = token_level_rewards.shape[-1] response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1) scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list) id2score = defaultdict(list)
id2mean = {} id2mean, id2std = {}, {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0] bsz = scores.shape[0]
for i in range(bsz): for i in range(bsz):
id2score[index[i]].append(scores[i]) id2score[index[i]].append(scores[i])
for idx in id2score: for idx in id2score:
if len(id2score[idx]) == 1: if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0) id2mean[idx] = torch.tensor(0.0)
...@@ -159,16 +162,64 @@ def compute_grpo_outcome_advantage( ...@@ -159,16 +162,64 @@ def compute_grpo_outcome_advantage(
id2std[idx] = torch.std(torch.tensor([id2score[idx]])) id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else: else:
raise ValueError(f"no score in prompt index: {idx}") raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz): for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
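`compute_grpo_outcome_advantage` normalizes each response's scalar score by the mean and std of its prompt group. A compact sketch of that normalization (tiling over the response length and masking omitted; names are illustrative):

```python
from collections import defaultdict

import torch


def grpo_advantage_sketch(scores: torch.Tensor, index: list, eps: float = 1e-6) -> torch.Tensor:
    """Group-normalize one scalar score per response by its prompt index."""
    groups = defaultdict(list)
    for i, idx in enumerate(index):
        groups[idx].append(scores[i])
    out = scores.clone()
    for i, idx in enumerate(index):
        vals = torch.stack(groups[idx])
        mean = vals.mean() if len(vals) > 1 else torch.tensor(0.0)
        std = vals.std() if len(vals) > 1 else torch.tensor(1.0)
        out[i] = (scores[i] - mean) / (std + eps)
    return out


scores = torch.tensor([1.0, 0.0, 1.0, 1.0])
print(grpo_advantage_sketch(scores, index=["p0", "p0", "p1", "p1"]))
# tensor([ 0.7071, -0.7071,  0.0000,  0.0000])
```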
@torch.no_grad()
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}.")
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
response_num - 1
)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores return scores, scores
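The RLOO scaling above is equivalent to subtracting, for each response, the mean score of the other responses to the same prompt. A single-group sketch (the helper name is hypothetical):

```python
import torch


def rloo_sketch(scores: torch.Tensor) -> torch.Tensor:
    """Leave-one-out baseline for one group of k responses to the same prompt,
    using the same k/(k-1) scaling as above."""
    k = scores.numel()
    if k < 2:
        return scores
    # equals: score_i minus the mean score of the other k - 1 responses
    return scores * k / (k - 1) - scores.mean() * k / (k - 1)


print(rloo_sketch(torch.tensor([1.0, 0.0, 0.0, 1.0])))
# tensor([ 0.6667, -0.6667, -0.6667,  0.6667])
```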
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage( def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Compute advantage for REINFORCE++. Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262 This implementation is based on the paper: https://arxiv.org/abs/2501.03262
...@@ -184,26 +235,24 @@ def compute_reinforce_plus_plus_outcome_advantage( ...@@ -184,26 +235,24 @@ def compute_reinforce_plus_plus_outcome_advantage(
Returns: `(torch.Tensor)` Returns: `(torch.Tensor)`
shape: (bs, response_length) shape: (bs, response_length)
""" """
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards) returns = torch.zeros_like(token_level_rewards)
running_return = 0 running_return = 0
for t in reversed(range(token_level_rewards.shape[1])): for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return returns[:, t] = running_return
# Reset after EOS # Reset after EOS
running_return = running_return * eos_mask[:, t] running_return = running_return * eos_mask[:, t]
advantages = verl_F.masked_whiten(returns, eos_mask) advantages = VF.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask advantages *= eos_mask
returns *= eos_mask
return advantages, returns return advantages, returns
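REINFORCE++ accumulates a discounted return-to-go and resets it at EOS positions before whitening. A toy walk-through of the loop above (numbers are made up; whitening omitted):

```python
import torch

rewards = torch.tensor([[0.0, 0.0, 1.0]])
eos_mask = torch.tensor([[1.0, 1.0, 1.0]])
gamma = 0.9

returns = torch.zeros_like(rewards)
running_return = torch.zeros(rewards.shape[0])
for t in reversed(range(rewards.shape[1])):
    running_return = rewards[:, t] + gamma * running_return
    returns[:, t] = running_return
    running_return = running_return * eos_mask[:, t]  # reset after EOS

print(returns)  # tensor([[0.8100, 0.9000, 1.0000]])
```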
@torch.no_grad()
def compute_remax_outcome_advantage( def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Compute advantage for ReMax, operating only on Outcome reward Compute advantage for ReMax, operating only on Outcome reward
This implementation is based on the paper: https://arxiv.org/abs/2310.10505 This implementation is based on the paper: https://arxiv.org/abs/2310.10505
...@@ -225,23 +274,31 @@ def compute_remax_outcome_advantage( ...@@ -225,23 +274,31 @@ def compute_remax_outcome_advantage(
""" """
response_length = token_level_rewards.shape[-1] response_length = token_level_rewards.shape[-1]
# scores = token_level_rewards.sum(dim=-1) # scores = token_level_rewards.sum(dim=-1)
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * eos_mask
with torch.no_grad():
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
return advantages, returns return advantages, returns
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): def compute_rewards(
kl = old_log_prob - ref_log_prob token_level_scores: torch.Tensor,
log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
kl_ratio: float,
) -> torch.Tensor:
kl = log_probs - ref_log_probs
return token_level_scores - kl * kl_ratio return token_level_scores - kl * kl_ratio
def compute_policy_loss( def compute_policy_loss(
old_log_prob, log_prob, advantages, eos_mask, cliprange old_log_probs: torch.Tensor,
log_probs: torch.Tensor,
advantages: torch.Tensor,
eos_mask: torch.Tensor,
cliprange: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 """Compute the policy loss.
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
Args: Args:
old_log_prob: `(torch.Tensor)` old_log_prob: `(torch.Tensor)`
...@@ -260,95 +317,88 @@ def compute_policy_loss( ...@@ -260,95 +317,88 @@ def compute_policy_loss(
policy gradient loss computed via PPO policy gradient loss computed via PPO
pg_clipfrac: (float) pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped a float number indicating the fraction of policy gradient loss being clipped
""" """
negative_approx_kl = log_prob - old_log_prob negative_approx_kl = log_probs - old_log_probs
# clamp the ratio before exp to avoid nan
# see: https://github.com/pytorch/pytorch/issues/10729
ratio = torch.exp(negative_approx_kl) ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask) clipped_ratio = torch.exp(torch.clamp(negative_approx_kl, np.log(1.0 - cliprange), np.log(1.0 + cliprange)))
ppo_kl = VF.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) pg_losses2 = -advantages * clipped_ratio
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask) pg_loss = VF.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask) pg_clipfrac = VF.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl return pg_loss, pg_clipfrac, ppo_kl
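`compute_policy_loss` now clamps the log-ratio before exponentiating (rather than clamping the ratio itself), which avoids overflow in `exp`. Below is an unmasked sketch of the standard clipped surrogate it implements; the function name is hypothetical and `eos_mask` averaging is omitted:

```python
import torch


def ppo_clip_loss_sketch(old_log_probs, log_probs, advantages, cliprange=0.2):
    """Unmasked PPO clipped surrogate; illustrative only."""
    ratio = torch.exp(log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    return torch.max(-advantages * ratio, -advantages * clipped_ratio).mean()


old_lp = torch.tensor([-1.0, -1.0])
new_lp = torch.tensor([-0.5, -1.5])
adv = torch.tensor([1.0, -1.0])
print(ppo_clip_loss_sketch(old_lp, new_lp, adv))
```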
def compute_entropy_loss(logits, eos_mask): def compute_value_loss(
"""Compute Categorical entropy loss vpreds: torch.Tensor,
returns: torch.Tensor,
Args: values: torch.Tensor,
logits: `(torch.Tensor)` eos_mask: torch.Tensor,
shape: (bs, response_length, vocab_size) cliprange_value: float,
eos_mask: `(torch.Tensor)` ) -> Tuple[torch.Tensor, float]:
shape: (bs, response_length) """Compute the value loss.
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value): Copied from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args: Args:
vpreds (`torch.FloatTensor`): vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`) Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`): returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`) Ground truth returns, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange_value: (float)
The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347
Returns: Returns:
vf_loss: a scalar (`torch.FloatTensor`): vf_loss: a scalar (`torch.FloatTensor`):
value function loss value function loss
vf_clipfrac: a float vf_clipfrac: a float
The ratio of vf being clipped The ratio of vf being clipped
""" """
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2 vf_losses1 = torch.square(vpreds - returns)
vf_losses2 = (vpredclipped - returns) ** 2 vf_losses2 = torch.square(vpredclipped - returns)
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask) vf_loss = 0.5 * VF.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask) vf_clipfrac = VF.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac return vf_loss, vf_clipfrac
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.Tensor: def kl_penalty(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
"""Compute KL divergence given logprob and ref_logprob. """Compute KL divergence given log_probs and ref_log_probs.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Args: Args:
logprob: log_probs: torch.Tensor
ref_logprob: ref_log_probs: torch.Tensor
Returns: Returns:
kl_div: torch.Tensor
""" """
log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
if kl_penalty == "kl": if kl_penalty == "kl":
return logprob - ref_logprob return log_probs - ref_log_probs
if kl_penalty == "abs": if kl_penalty == "abs":
return (logprob - ref_logprob).abs() return (log_probs - ref_log_probs).abs()
if kl_penalty == "mse": if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square() return 0.5 * (log_probs - ref_log_probs).square()
# J. Schulman. Approximating kl divergence, 2020. # J. Schulman. Approximating kl divergence, 2020.
# # URL http://joschu.net/blog/kl-approx.html. # URL http://joschu.net/blog/kl-approx.html
if kl_penalty == "low_var_kl": if kl_penalty == "low_var_kl":
kl = ref_logprob - logprob kl = ref_log_probs - log_probs
ratio = torch.exp(kl) kld = (kl.exp() - kl - 1).contiguous()
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10) return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full": if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)
raise NotImplementedError
raise NotImplementedError raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
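`kl_penalty` supports several per-token KL estimators. A toy comparison on made-up log-probs (values are illustrative only):

```python
import torch

log_probs = torch.tensor([-1.0, -2.0, -0.5])
ref_log_probs = torch.tensor([-1.2, -1.5, -0.7])

k1 = log_probs - ref_log_probs                                             # "kl"
k_abs = k1.abs()                                                           # "abs"
k_mse = 0.5 * k1.square()                                                  # "mse"
k3 = (ref_log_probs - log_probs).exp() - (ref_log_probs - log_probs) - 1   # "low_var_kl"
print(k1, k_abs, k_mse, k3, sep="\n")
```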
...@@ -16,41 +16,40 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o ...@@ -16,41 +16,40 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o
""" """
import json import json
import torch
import ray import ray
from omegaconf import OmegaConf
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.config import PPOConfig
from verl.trainer.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
from verl.utils import get_processor, get_tokenizer
from verl.workers.fsdp_workers import FSDPWorker
from verl.workers.reward import CustomRewardManager
from omegaconf import OmegaConf
def main(): from ..single_controller.ray import RayWorkerGroup
cli_args = OmegaConf.from_cli() from ..utils.tokenizer import get_processor, get_tokenizer
file_config = OmegaConf.load(cli_args.config) from ..workers.fsdp_workers import FSDPWorker
del cli_args.config from ..workers.reward import CustomRewardManager
from .config import PPOConfig
default_config = OmegaConf.structured(PPOConfig()) from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
ppo_config = OmegaConf.merge(default_config, file_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(ppo_config))
# please make sure main_task is not scheduled on head
@ray.remote(num_cpus=1)
class Runner:
"""A runner for RL training."""
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head def run(self, config: PPOConfig):
def main_task(config: PPOConfig): # print config
config.deep_post_init() config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2)) print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer # instantiate tokenizer
tokenizer = get_tokenizer(config.worker.actor.model.model_path) tokenizer = get_tokenizer(
processor = get_processor(config.worker.actor.model.model_path, use_fast=True) config.worker.actor.model.model_path,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
processor = get_processor(
config.worker.actor.model.model_path,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
# define worker classes # define worker classes
ray_worker_group_cls = RayWorkerGroup ray_worker_group_cls = RayWorkerGroup
...@@ -59,7 +58,6 @@ def main_task(config: PPOConfig): ...@@ -59,7 +58,6 @@ def main_task(config: PPOConfig):
Role.Critic: ray.remote(FSDPWorker), Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker), Role.RefPolicy: ray.remote(FSDPWorker),
} }
global_pool_id = "global_pool" global_pool_id = "global_pool"
resource_pool_spec = { resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
...@@ -69,16 +67,11 @@ def main_task(config: PPOConfig): ...@@ -69,16 +67,11 @@ def main_task(config: PPOConfig):
Role.Critic: global_pool_id, Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id, Role.RefPolicy: global_pool_id,
} }
reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
val_reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
trainer = RayPPOTrainer( trainer = RayPPOTrainer(
config=config, config=config,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -93,5 +86,31 @@ def main_task(config: PPOConfig): ...@@ -93,5 +86,31 @@ def main_task(config: PPOConfig):
trainer.fit() trainer.fit()
def main():
cli_args = OmegaConf.from_cli()
default_config = OmegaConf.structured(PPOConfig())
if hasattr(cli_args, "config"):
config_path = cli_args.pop("config", None)
file_config = OmegaConf.load(config_path)
default_config = OmegaConf.merge(default_config, file_config)
ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
# this is for local ray cluster
if not ray.is_initialized():
# for rocm
if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ignore_reinit_error=True,
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
else:
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
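`main()` merges three config sources in increasing priority: the structured defaults, the optional YAML file passed via `config=...`, and the remaining CLI overrides. A toy illustration of that merge order with hypothetical stand-in dataclasses (not the repository's `PPOConfig`):

```python
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class ToyTrainer:  # stand-in for TrainerConfig
    nnodes: int = 1
    n_gpus_per_node: int = 8


@dataclass
class ToyPPO:  # stand-in for PPOConfig
    trainer: ToyTrainer = field(default_factory=ToyTrainer)


default_config = OmegaConf.structured(ToyPPO())
file_config = OmegaConf.create({"trainer": {"nnodes": 2}})          # stands in for the YAML file
cli_config = OmegaConf.from_dotlist(["trainer.n_gpus_per_node=4"])  # stands in for CLI args
merged = OmegaConf.merge(default_config, file_config, cli_config)
assert merged.trainer.nnodes == 2 and merged.trainer.n_gpus_per_node == 4
```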
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
import numpy as np
import torch
from ..protocol import DataProto
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
return {key: np.mean(value) for key, value in metrics.items()}
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].size(-1)
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float()
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
num_response_tokens = torch.sum(batch.batch["response_mask"]).item()
num_overall_tokens = sum(batch.meta_info["global_token_num"])
num_tokens_of_section = {
**dict.fromkeys(["gen", "reward"], num_response_tokens),
**dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens),
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
total_num_tokens = sum(batch.meta_info["global_token_num"])
time = timing_raw["step"]
return {
"perf/total_num_tokens": total_num_tokens,
"perf/time_per_step": time,
"perf/throughput": total_num_tokens / (time * n_gpus),
}
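`compute_throughout_metrics` reports tokens processed per second per GPU. A quick numeric example with made-up numbers:

```python
# 1.6M tokens in a 40-second step on 16 GPUs gives 2500 tokens/s per GPU.
total_num_tokens, step_time, n_gpus = 1_600_000, 40.0, 16
throughput = total_num_tokens / (step_time * n_gpus)
print(f"perf/throughput: {throughput:.0f}")  # 2500
```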
...@@ -11,8 +11,3 @@ ...@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .tokenizer import get_processor, get_tokenizer
__all__ = ["get_processor", "get_tokenizer"]
...@@ -11,3 +11,8 @@ ...@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
__all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
...@@ -11,20 +11,27 @@ ...@@ -11,20 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import random import random
import re
import shutil import shutil
import tempfile import tempfile
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.distributed import torch.distributed as dist
from filelock import FileLock from filelock import FileLock
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
class BaseCheckpointManager: CHECKPOINT_TRACKER = "latest_global_step.txt"
class BaseCheckpointManager(ABC):
""" """
A checkpoint manager that saves and loads A checkpoint manager that saves and loads
- model - model
...@@ -44,42 +51,27 @@ class BaseCheckpointManager: ...@@ -44,42 +51,27 @@ class BaseCheckpointManager:
model: FSDP, model: FSDP,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler, lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer, processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
processor: ProcessorMixin
): ):
self.previous_global_step = None
self.previous_save_local_path = None
self.model = model self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer self.processing_class = processing_class
self.processor = processor
assert isinstance(self.model, FSDP) assert isinstance(self.model, FSDP)
self.rank = torch.distributed.get_rank() self.rank = dist.get_rank()
self.world_size = torch.distributed.get_world_size() self.world_size = dist.get_world_size()
@abstractmethod
def load_checkpoint(self, *args, **kwargs): def load_checkpoint(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@abstractmethod
def save_checkpoint(self, *args, **kwargs): def save_checkpoint(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
def remove_previous_save_local_path(self):
if not self.previous_save_local_path:
return
abs_path = os.path.abspath(self.previous_save_local_path)
print(f"Checkpoint manager remove previous save local path: {abs_path}")
if not os.path.exists(abs_path):
return
# remove previous local_path
shutil.rmtree(abs_path, ignore_errors=True)
@staticmethod @staticmethod
def local_mkdir(path): def local_mkdir(path: str) -> str:
if not os.path.isabs(path): if not os.path.isabs(path):
working_dir = os.getcwd() working_dir = os.getcwd()
path = os.path.join(working_dir, path) path = os.path.join(working_dir, path)
...@@ -89,18 +81,16 @@ class BaseCheckpointManager: ...@@ -89,18 +81,16 @@ class BaseCheckpointManager:
lock_path = os.path.join(tempfile.gettempdir(), lock_filename) lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
try: try:
with FileLock(lock_path, timeout=60): # Add timeout with FileLock(lock_path, timeout=60):
# make a new dir
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
except Exception as e: except Exception as e:
print(f"Warning: Failed to acquire lock for {path}: {e}") print(f"Warning: Failed to acquire lock for {path}: {e}")
# Even if the lock is not acquired, try to create the directory os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory
os.makedirs(path, exist_ok=True)
return path return path
@staticmethod @staticmethod
def get_rng_state(): def get_rng_state() -> Dict[str, Any]:
rng_state = { rng_state = {
"cpu": torch.get_rng_state(), "cpu": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state(), "cuda": torch.cuda.get_rng_state(),
...@@ -110,14 +100,14 @@ class BaseCheckpointManager: ...@@ -110,14 +100,14 @@ class BaseCheckpointManager:
return rng_state return rng_state
@staticmethod @staticmethod
def load_rng_state(rng_state): def load_rng_state(rng_state: Dict[str, Any]):
torch.set_rng_state(rng_state["cpu"]) torch.set_rng_state(rng_state["cpu"])
torch.cuda.set_rng_state(rng_state["cuda"]) torch.cuda.set_rng_state(rng_state["cuda"])
np.random.set_state(rng_state["numpy"]) np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"]) random.setstate(rng_state["random"])
def find_latest_ckpt_path(path, directory_format="global_step_{}"): def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
if path is None: if path is None:
return None return None
...@@ -128,6 +118,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"): ...@@ -128,6 +118,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
with open(tracker_file, "rb") as f: with open(tracker_file, "rb") as f:
iteration = int(f.read().decode()) iteration = int(f.read().decode())
ckpt_path = os.path.join(path, directory_format.format(iteration)) ckpt_path = os.path.join(path, directory_format.format(iteration))
if not os.path.exists(ckpt_path): if not os.path.exists(ckpt_path):
print("Checkpoint does not exist: %s", ckpt_path) print("Checkpoint does not exist: %s", ckpt_path)
...@@ -137,8 +128,33 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"): ...@@ -137,8 +128,33 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
return ckpt_path return ckpt_path
def get_checkpoint_tracker_filename(root_path: str): def get_checkpoint_tracker_filename(root_path: str) -> str:
""" """
Tracker file records the latest checkpoint during training to restart from. Tracker file records the latest checkpoint during training to restart from.
""" """
return os.path.join(root_path, "latest_checkpointed_iteration.txt") return os.path.join(root_path, CHECKPOINT_TRACKER)
def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
"""
Remove the obsolete checkpoints that exceed the save_limit.
"""
if save_limit <= 0:
return
if not os.path.exists(path):
return
pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
ckpt_folders = []
for folder in os.listdir(path):
if match := re.match(pattern, folder):
step = int(match.group(1))
if step < global_step:
ckpt_folders.append((step, folder))
ckpt_folders.sort(reverse=True)
for _, folder in ckpt_folders[save_limit - 1 :]:
folder_path = os.path.join(path, folder)
shutil.rmtree(folder_path, ignore_errors=True)
print(f"Removed obsolete checkpoint: {folder_path}")