Commit c132cbcb authored by chenych

0402 update

parent f92481f0
......@@ -12,11 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_requires():
def get_version() -> str:
with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"__version__\W*=\W*\"([^\"]+)\""
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
......@@ -31,18 +41,19 @@ extra_require = {
def main():
setup(
name="verl",
version="0.2.0.dev0",
package_dir={"": "."},
packages=find_packages(where="."),
url="https://github.com/volcengine/verl",
license="Apache 2.0",
version=get_version(),
description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
author="verl",
author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
description="",
license="Apache 2.0 License",
url="https://github.com/volcengine/verl",
package_dir={"": "."},
packages=find_packages(where="."),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
)
......
......@@ -12,8 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .protocol import DataProto
__all__ = ["DataProto"]
__version__ = "0.2.0.dev"
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .transformers.flash_attention_utils import flash_attention_forward
from .transformers.qwen2_vl import qwen2_vl_attn_forward
def apply_ulysses_patch(model_type: str) -> None:
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
elif model_type in ("qwen2_vl", "qwen2_5_vl"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
else:
raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
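# Hedged usage sketch (not part of this commit): the patch is expected to be applied once per
# process, before any forward pass, using the Hugging Face `model_type` string; the checkpoint
# name below is only an illustration.
def _example_apply_ulysses_patch() -> None:
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    apply_ulysses_patch(config.model_type)  # "qwen2_vl" for this checkpoint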
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
from ...utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_world_size,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
_flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def prepare_fa2_from_position_ids(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seqlens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
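# Hedged self-check (illustrative only): a packed batch holding two sequences of lengths 3 and 2.
def _example_prepare_fa2_from_position_ids() -> None:
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # zeros mark the start of each packed sequence
    dummy = torch.zeros(1, 5, 1, 8)  # (batch_size, seq_length, num_heads, head_dim)
    *_, (cu_seqlens, _), (max_len, _) = prepare_fa2_from_position_ids(dummy, dummy, dummy, position_ids)
    # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32); max_len -> tensor(3)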
def _custom_flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
query_length: int,
is_causal: bool = True,
position_ids: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
deterministic: Optional[bool] = None,
**kwargs,
):
"""
Patched flash attention forward that also handles the 3D position ids used by mrope, shaped (3, batch_size, seq_length).
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if _flash_supports_deterministic:
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
if kwargs.get("softcap") is not None:
flash_kwargs["softcap"] = kwargs.pop("softcap")
query_states, key_states, value_states = fa_peft_integration_check(
query_states, key_states, value_states, target_dtype=torch.bfloat16
)
sp_size = get_ulysses_sequence_parallel_world_size()
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())  # synchronous; results land in position_ids_lst
position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids[0]
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=kwargs.pop("dropout", 0.0),
softmax_scale=kwargs.pop("softmax_scale", None),
causal=causal,
**flash_kwargs,
)
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_length,
is_causal=is_causal,
sliding_window=sliding_window,
use_top_left_mask=use_top_left_mask,
deterministic=deterministic,
**kwargs,
) # do not pass position_ids to old flash_attention_forward
if sp_size > 1:
# (batch_size, seq_length, num_head, head_size)
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
return attn_output
def flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
# This is before the transpose
q_len = query.shape[2]
# FA2 uses non-transposed inputs
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
kwargs.pop("is_causal", None)
attn_output = _custom_flash_attention_forward(
query,
key,
value,
attention_mask,
query_length=q_len,
is_causal=True,
dropout=dropout,
softmax_scale=scaling,
sliding_window=sliding_window,
softcap=softcap,
use_top_left_mask=_flash_use_top_left_mask,
**kwargs,
)
return attn_output, None
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from .flash_attention_utils import flash_attention_forward
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
apply_multimodal_rotary_pos_emb,
repeat_kv,
)
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
except ImportError:
pass
def get_rope_index(
processor: "Qwen2VLProcessor",
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
video_grid_thw: Optional[torch.Tensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Gets the position ids for Qwen2-VL; they should be generated before sharding the sequence.
The batch dim has been removed, and input_ids should be a 1D tensor representing a single example.
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
"""
spatial_merge_size = processor.image_processor.merge_size
tokens_per_second = 2
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
image_index, video_index = 0, 0
input_ids = input_ids[attention_mask == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
else:
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
return position_ids
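# Worked sketch (illustrative numbers only): for 4 leading text tokens followed by 4 image tokens
# whose merged grid is (t=1, h=2, w=2), the three position rows are
#   t: [0, 1, 2, 3, 4, 4, 4, 4]
#   h: [0, 1, 2, 3, 4, 4, 5, 5]
#   w: [0, 1, 2, 3, 4, 5, 4, 5]
# i.e. text advances all three rows together, image tokens index the (t, h, w) grid starting at
# st_idx = 4, and any trailing text would resume from max(position) + 1 = 6.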
def qwen2_vl_attn_forward(
self: "Qwen2VLAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, None, None]:
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Because the input can be padded, the absolute sequence length depends on the max position id.
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
sliding_window = None
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
attn_output, _ = flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=dropout_rate,
sliding_window=sliding_window,
position_ids=position_ids, # important: pass position ids
) # (batch_size, seq_length, num_head / sp_size, head_size)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, None, None
......@@ -26,13 +26,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import ray
import torch
import torch.distributed as dist
from numpy.typing import NDArray
from tensordict import TensorDict
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader
from verl.utils.py_functional import union_two_dict
from .utils.py_functional import union_two_dict
try:
......@@ -89,21 +88,22 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
)
for key, value in tensor_dict2.items():
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], value):
for key in tensor_dict2.keys():
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = value
tensor_dict1[key] = tensor_dict2[key]
return tensor_dict1
def union_numpy_dict(
tensor_dict1: Dict[str, Union[List, NDArray]], tensor_dict2: Dict[str, Union[List, NDArray]]
) -> Dict[str, Union[List, NDArray]]:
for key, value in tensor_dict2.items():
if key in tensor_dict1 and isinstance(value, np.ndarray) and not np.all(tensor_dict1[key] == value):
raise ValueError(f"Key already exists: {key}.")
def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
for key in tensor_dict2.keys():
if key in tensor_dict1:
assert isinstance(tensor_dict2[key], np.ndarray)
assert isinstance(tensor_dict1[key], np.ndarray)
if not np.all(tensor_dict1[key] == tensor_dict2[key]):
raise ValueError(f"Key already exists: {key}.")
tensor_dict1[key] = tensor_dict2[key]
......@@ -151,6 +151,7 @@ def collate_fn(data_items: list["DataProtoItem"]):
batch = torch.stack(batch).contiguous()
non_tensor_batch = batch_collate(non_tensor_batch)
non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()}
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
......@@ -189,7 +190,8 @@ class DataProto:
def __getitem__(self, item):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
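# Hedged illustration of the behavior above: integer indexing still yields a DataProtoItem,
# while slicing now yields a full DataProto, e.g.
#   data[0]    -> DataProtoItem
#   data[0:4]  -> DataProto with batch size 4 (assuming len(data) >= 4)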
def __getstate__(self):
buffer = io.BytesIO()
......@@ -229,8 +231,7 @@ class DataProto:
size_of_numpy_array = 0
for value in self.non_tensor_batch.values():
if isinstance(value, np.ndarray):
size_of_numpy_array += value.nbytes
size_of_numpy_array += value.nbytes
size_of_numpy_array /= 1024**3
size_of_tensordict /= 1024**3
......@@ -254,13 +255,13 @@ class DataProto:
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
@classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, list, NDArray]], meta_info=None):
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, NDArray]], meta_info=None):
tensors = {}
non_tensors = {}
for key, value in data.items():
if isinstance(value, torch.Tensor):
tensors[key] = value
elif isinstance(value, (list, np.ndarray)):
elif isinstance(value, np.ndarray):
non_tensors[key] = value
else:
raise ValueError(f"Unsupported type in data {type(value)}")
......@@ -472,8 +473,6 @@ class DataProto:
assert len(self) % chunks == 0, (
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
)
chunk_size = len(self) // chunks
if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
else:
......@@ -481,12 +480,8 @@ class DataProto:
non_tensor_batch_lst = [{} for _ in range(chunks)]
for key, value in self.non_tensor_batch.items():
assert isinstance(value, (list, np.ndarray))
if isinstance(value, np.ndarray):
non_tensor_lst = np.array_split(value, chunks)
else:
non_tensor_lst = [value[i : i + chunk_size] for i in range(0, len(self), chunk_size)]
assert isinstance(value, np.ndarray)
non_tensor_lst = np.array_split(value, chunks)
assert len(non_tensor_lst) == chunks
for i in range(chunks):
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
......@@ -524,12 +519,7 @@ class DataProto:
non_tensor_batch = batch_collate([d.non_tensor_batch for d in data])
for key, value in non_tensor_batch.items():
if isinstance(value[0], np.ndarray):
non_tensor_batch[key] = np.concatenate(value, axis=0)
else:
non_tensor_batch[key] = []
for item in value:
non_tensor_batch[key].extend(item)
non_tensor_batch[key] = np.concatenate(value, axis=0)
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
......@@ -574,16 +564,10 @@ class DataProto:
repeated_non_tensor_batch = {}
for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray):
if interleave:
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
else:
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
if interleave:
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
else:
if interleave:
repeated_non_tensor_batch[key] = [item for item in value for _ in range(repeat_times)]
else:
repeated_non_tensor_batch[key] = [item for _ in range(repeat_times) for item in value]
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
return DataProto(
batch=repeated_batch,
......@@ -591,39 +575,6 @@ class DataProto:
meta_info=self.meta_info,
)
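# Worked example for the repeat above (illustrative): for a non-tensor value np.array(["a", "b"]) and
# repeat_times=2, interleave=True yields ["a", "a", "b", "b"] (np.repeat along axis 0), while
# interleave=False yields ["a", "b", "a", "b"] (np.tile along the first axis).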
def broadcast(self, src: int, group: Optional[ProcessGroup] = None):
for key in self.batch.sorted_keys:
dist.broadcast(self.batch[key], src=src, group=group, async_op=False)
object_list = [self.non_tensor_batch]
dist.broadcast_object_list(object_list, src=src, group=group)
self.non_tensor_batch = object_list[0]
def all_gather(self, group: Optional[ProcessGroup] = None):
world_size = dist.get_world_size(group)
output = {}
for key in self.batch.sorted_keys:
value = self.batch[key].contiguous()
output[key] = [torch.empty_like(value) for _ in range(world_size)]
dist.all_gather(output[key], value, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=0)
self.batch = TensorDict(output, batch_size=self.batch.batch_size[0] * world_size)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(world_size)]
dist.all_gather_object(all_non_tensor_batch, self.non_tensor_batch, group=group)
non_tensor_batch = defaultdict(list)
for key, value in self.non_tensor_batch.items():
if isinstance(value, np.ndarray):
non_tensor_batch[key] = np.concatenate([batch[key] for batch in all_non_tensor_batch])
else:
for batch in all_non_tensor_batch:
non_tensor_batch[key].extend(batch[key])
self.non_tensor_batch = non_tensor_batch
self.check_consistency()
@dataclass
class DataProtoFuture:
......@@ -664,10 +615,53 @@ class DataProtoFuture:
return arg_future_lst
def get(self):
output = ray.get(self.futures) # dp_size.
for o in output:
assert isinstance(o, DataProto)
output = self.collect_fn(output) # select dp, concat
outputs = ray.get(self.futures) # dp_size.
for output in outputs:
assert isinstance(output, DataProto)
outputs = self.collect_fn(outputs) # select dp, concat
if self.dispatch_fn is not None:
output = self.dispatch_fn(output) # split in batch dim, select using dp
return output
outputs = self.dispatch_fn(outputs) # split in batch dim, select using dp
return outputs
def allgather_dict_tensors(
tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
) -> Union[Dict[str, torch.Tensor], TensorDict]:
"""
TODO: optimize this.
- We can use async ops
- We can use only one allgather
"""
if isinstance(tensors, TensorDict):
is_tensor_dict = True
tensors_as_dict = tensors.to_dict()
else:
tensors_as_dict = tensors
is_tensor_dict = False
output = {}
sorted_keys = sorted(tensors_as_dict.keys())
for key in sorted_keys:
val = tensors_as_dict[key]
output[key] = [torch.empty_like(val) for _ in range(size)]
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=dim)
if is_tensor_dict:
output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)
return output
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
# Note that this is an in-place operation, just like torch.distributed.all_gather
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
all_non_tensor_batch = [None for _ in range(size)]
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
......@@ -14,3 +14,6 @@
from .worker import Worker
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
......@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from enum import Enum, auto
from functools import wraps
from types import FunctionType
from typing import Dict, List
from typing import TYPE_CHECKING, Dict, List, Literal, Union
import ray
from verl.protocol import DataProto, DataProtoFuture
from ...protocol import DataProto, DataProtoFuture
if TYPE_CHECKING:
from .worker_group import WorkerGroup
# here we add a magic number to avoid a user-defined function already having this attribute
......@@ -27,13 +31,13 @@ MAGIC_ATTR = "attrs_3141562937"
class Dispatch(Enum):
RANK_ZERO = 0
ONE_TO_ALL = 1
ALL_TO_ALL = 2
DP_COMPUTE = 3
DP_COMPUTE_PROTO = 4
DP_COMPUTE_PROTO_WITH_FUNC = 5
DP_COMPUTE_METRIC = 6
RANK_ZERO = auto()
ONE_TO_ALL = auto()
ALL_TO_ALL = auto()
DP_COMPUTE = auto()
DP_COMPUTE_PROTO = auto()
DP_COMPUTE_PROTO_WITH_FUNC = auto()
DP_COMPUTE_METRIC = auto()
class Execute(Enum):
......@@ -41,100 +45,85 @@ class Execute(Enum):
RANK_ZERO = 1
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
splitted_kwargs[key] = val.chunk(chunks=chunks)
for key, value in kwargs.items():
assert isinstance(value, (DataProto, DataProtoFuture))
splitted_kwargs[key] = value.chunk(chunks=chunks)
return splitted_args, splitted_kwargs
def dispatch_one_to_all(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
args = tuple([arg] * worker_group.world_size for arg in args)
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
return args, kwargs
def dispatch_all_to_all(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
return args, kwargs
def collect_all_to_all(worker_group, output):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def collect_all_to_all(worker_group: "WorkerGroup", output):
return output
def _concat_data_proto_or_future(output: List):
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
# make sure all the elements in outputs have the same type
for o in output:
assert type(o) is type(output[0])
for output in outputs:
assert type(output) is type(outputs[0])
o = output[0]
output = outputs[0]
if isinstance(o, DataProto):
return DataProto.concat(output)
elif isinstance(o, ray.ObjectRef):
return DataProtoFuture.concat(output)
if isinstance(output, DataProto):
return DataProto.concat(outputs)
elif isinstance(output, ray.ObjectRef):
return DataProtoFuture.concat(outputs)
else:
raise NotImplementedError
def dispatch_dp_compute(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
return args, kwargs
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
for arg in args:
assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
for value in kwargs.values():
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
def collect_dp_compute(worker_group, output):
from verl.single_controller.base.worker_group import WorkerGroup
return args, kwargs
assert isinstance(worker_group, WorkerGroup)
assert len(output) == worker_group.world_size
return output
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
assert len(outputs) == worker_group.world_size
return outputs
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
return splitted_args, splitted_kwargs
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
from verl.single_controller.base.worker_group import WorkerGroup
assert isinstance(worker_group, WorkerGroup)
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
assert type(args[0]) is FunctionType  # NOTE: the first arg must be a function!
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
return splitted_args_with_func, splitted_kwargs
def collect_dp_compute_data_proto(worker_group, output):
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
for output in outputs:
assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expected a DataProto or ray.ObjectRef, but got {type(output)}"
output = collect_dp_compute(worker_group, output)
return _concat_data_proto_or_future(output)
outputs = collect_dp_compute(worker_group, outputs)
return _concat_data_proto_or_future(outputs)
def get_predefined_dispatch_fn(dispatch_mode):
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: {
"dispatch_fn": dispatch_one_to_all,
......@@ -144,6 +133,10 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_all_to_all,
"collect_fn": collect_all_to_all,
},
Dispatch.DP_COMPUTE: {
"dispatch_fn": dispatch_dp_compute,
"collect_fn": collect_dp_compute,
},
Dispatch.DP_COMPUTE_PROTO: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute_data_proto,
......@@ -152,11 +145,15 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_dp_compute_data_proto_with_func,
"collect_fn": collect_dp_compute_data_proto,
},
Dispatch.DP_COMPUTE_METRIC: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute,
},
}
return predefined_dispatch_mode_fn[dispatch_mode]
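# Hedged sketch (not part of this commit): how a dispatch/collect pair is meant to be used.
# `worker_group` is assumed to be a WorkerGroup and `data` a DataProto whose batch size is
# divisible by world_size; the remote execution step is elided and replaced by a placeholder.
def _example_dp_compute_proto(worker_group: "WorkerGroup", data: DataProto) -> DataProto:
    fns = get_predefined_dispatch_fn(Dispatch.DP_COMPUTE_PROTO)
    args, kwargs = fns["dispatch_fn"](worker_group, data)  # data is chunked into world_size pieces
    outputs = list(args[0])  # placeholder: each chunk would actually be processed by its worker
    return fns["collect_fn"](worker_group, outputs)  # per-worker outputs are concatenated back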
def get_predefined_execute_fn(execute_mode):
def get_predefined_execute_fn(execute_mode: Execute):
"""
Note that here we only ask execute_all and execute_rank_zero to be implemented.
The choice of how these two functions handle the 'blocking' argument is left to users.
......@@ -168,17 +165,17 @@ def get_predefined_execute_fn(execute_mode):
return predefined_execute_mode_fn[execute_mode]
def _check_dispatch_mode(dispatch_mode):
assert isinstance(dispatch_mode, (Dispatch, Dict)), (
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
assert isinstance(dispatch_mode, (Dispatch, dict)), (
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
)
if isinstance(dispatch_mode, Dict):
if isinstance(dispatch_mode, dict):
necessary_keys = ["dispatch_fn", "collect_fn"]
for key in necessary_keys:
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
def _check_execute_mode(execute_mode):
def _check_execute_mode(execute_mode: Execute):
assert isinstance(execute_mode, Execute), f"execute_mode must be an Execute. Got {execute_mode}"
......@@ -189,9 +186,10 @@ def _materialize_futures(*args, **kwargs):
arg = arg.get()
# add more type to materialize
new_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, DataProtoFuture):
kwargs[k] = v.get()
for key, value in kwargs.items():
if isinstance(value, DataProtoFuture):
kwargs[key] = value.get()
new_args = tuple(new_args)
return new_args, kwargs
......
......@@ -18,11 +18,13 @@ the class for Worker
import os
import socket
from dataclasses import dataclass
from typing import Tuple
import ray
import torch
from verl.single_controller.base.decorator import Dispatch, Execute, register
from verl.single_controller.base.register_center.ray import create_worker_group_register_center
from .decorator import Dispatch, Execute, register
from .register_center.ray import create_worker_group_register_center
@dataclass
......@@ -40,7 +42,7 @@ class DistGlobalInfo:
class WorkerHelper:
def _get_node_ip(self):
def _get_node_ip(self) -> str:
host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6
......@@ -49,12 +51,12 @@ class WorkerHelper:
host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip
def _get_free_port(self):
def _get_free_port(self) -> int:
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
def get_availale_master_addr_port(self):
def get_availale_master_addr_port(self) -> Tuple[str, str]:
return self._get_node_ip(), str(self._get_free_port())
def _get_pid(self):
......@@ -81,16 +83,26 @@ class WorkerMeta:
# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
"""A (distributed) worker."""
_world_size: int
_rank: int
_local_world_size: int
_local_rank: int
_master_addr: str
_master_port: str
_cuda_visible_devices: str
def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
# note that we convert the env var to int, since a non-empty string like "0" would otherwise be truthy
disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0))
disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
if disable_worker_init:
return instance
rank = os.environ.get("RANK", None)
worker_group_prefix = os.environ.get("WG_PREFIX", None)
rank = os.getenv("RANK", None)
worker_group_prefix = os.getenv("WG_PREFIX", None)
# when the @ray.remote decorator is applied, __new__ is called, but we don't want to run _configure_before_init in that case
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
......@@ -112,13 +124,19 @@ class Worker(WorkerHelper):
def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from environment variables. Note that the import must be inside the class because it is executed remotely
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.getenv("RANK"))
self._rank = rank
self._world_size = world_size
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
if "AMD" in torch.cuda.get_device_name():
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
torch.cuda.set_device(int(cuda_visible_devices))
master_addr = os.getenv("MASTER_ADDR")
master_port = os.getenv("MASTER_PORT")
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
......@@ -149,6 +167,7 @@ class Worker(WorkerHelper):
if val is not None:
# print(f"set {key} to {val}")
os.environ[key] = str(val)
os.environ["REDIS_STORE_SERVER_HOST"] = (
str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
)
......@@ -157,7 +176,7 @@ class Worker(WorkerHelper):
return self._master_addr, self._master_port
def get_cuda_visible_devices(self):
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices
def print_rank0(self, *args, **kwargs):
......
......@@ -19,20 +19,20 @@ import logging
import signal
import threading
import time
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional
from verl.single_controller.base.decorator import (
MAGIC_ATTR,
Dispatch,
get_predefined_dispatch_fn,
get_predefined_execute_fn,
)
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
class ResourcePool:
def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
"""The resource pool with meta info such as world size."""
def __init__(
self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
) -> None:
if process_on_nodes is None:
process_on_nodes = []
self._store = process_on_nodes
self.max_collocate_count = max_collocate_count
self.n_gpus_per_node = n_gpus_per_node  # this is left for future Huawei nodes that contain 16 GPUs per node
......@@ -73,28 +73,23 @@ class ClassWithInitArgs:
self.args = args
self.kwargs = kwargs
# def add_arg(self, arg):
# self.args += (arg,)
# def add_kwarg(self, key, value):
# self.kwargs[key] = value
def __call__(self) -> Any:
return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
import time
while True:
for worker in workers:
if not is_alive(worker):
logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
signal.raise_signal(signal.SIGABRT)
time.sleep(gap_time)
class WorkerGroup:
"""A group of workers"""
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
self._is_init_with_detached_workers = True if resource_pool is None else False
......@@ -136,14 +131,10 @@ class WorkerGroup:
def world_size(self):
return len(self._workers)
# execute_all_async and execute_rank_zero_async should be implemented by RayWorkerGroup, TorchRPCWorkerGroup,
# MegatronWorkerGroup, XperfWorkerGroup should skip
def _bind_worker_method(self, user_defined_cls, func_generator):
"""
Bind the worker method to the WorkerGroup
"""
for method_name in dir(user_defined_cls):
try:
method = getattr(user_defined_cls, method_name)
......
......@@ -13,27 +13,28 @@
# limitations under the License.
import os
import random
import re
import string
import time
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
import ray
from ray.actor import ActorHandle
from ray.experimental.state.api import get_actor
from ray.util import list_named_actors
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
from verl.single_controller.base.decorator import MAGIC_ATTR
from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
from ..base.decorator import MAGIC_ATTR
__all__ = ["Worker"]
def get_random_string(length: int) -> str:
import random
import string
letters_digits = string.ascii_letters + string.digits
return "".join(random.choice(letters_digits) for _ in range(length))
......@@ -50,6 +51,27 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block
return func
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
"""
Sort the placement groups by node IP; all bundles in a single placement group should be on the same node.
FSDPCheckpointManager saves sharded model and optimizer states in local storage, which requires RANK
to be consistent across nodes when resuming from a checkpoint.
With this function, if there is only one resource pool and no node change, RANK stays consistent
across nodes in multiple Ray jobs, even if the whole Ray cluster is restarted.
"""
node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()}
pg_ip = {}
for pg in pgs:
specs = ray._private.state.state.placement_group_table(pg.id)
# all bundles should be on the same node
node_id = specs["bundles_to_node_id"][0]
pg_ip[pg.id] = node_ip[node_id]
return sorted(pgs, key=lambda pg: pg_ip[pg.id])
class RayResourcePool(ResourcePool):
def __init__(
self,
......@@ -57,7 +79,7 @@ class RayResourcePool(ResourcePool):
use_gpu: bool = True,
name_prefix: str = "",
max_colocate_count: int = 5,
detached=False,
detached: bool = False,
) -> None:
super().__init__(process_on_nodes, max_colocate_count)
self.use_gpu = use_gpu
......@@ -66,7 +88,7 @@ class RayResourcePool(ResourcePool):
self.pgs = None
self.detached = detached
def get_placement_groups(self, strategy="STRICT_PACK", name=None):
def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]:
if self.pgs is not None:
return self.pgs
......@@ -97,7 +119,7 @@ class RayResourcePool(ResourcePool):
def extract_pg_from_exist(
resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List:
) -> List[PlacementGroup]:
src_pgs = [
pg
for role_name, resource_pool in resource_pools.items()
......@@ -151,7 +173,12 @@ class RayClassWithInitArgs(ClassWithInitArgs):
self._options.update(options)
def __call__(
self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None
self,
placement_group: PlacementGroup,
placement_group_bundle_idx: int,
use_gpu: bool = True,
num_gpus: int = 1,
sharing_with: Worker = None,
) -> Any:
if sharing_with is not None:
target_node_id = ray.get(sharing_with.get_node_id.remote())
......@@ -188,8 +215,8 @@ class RayWorkerGroup(WorkerGroup):
ray_cls_with_init: RayClassWithInitArgs = None,
bin_pack: bool = True,
name_prefix: str = None,
detached=False,
worker_names=None,
detached: bool = False,
worker_names: List[str] = None,
**kwargs,
) -> None:
super().__init__(resource_pool=resource_pool, **kwargs)
......@@ -210,21 +237,24 @@ class RayWorkerGroup(WorkerGroup):
if ray_cls_with_init is not None:
self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)
def _is_worker_alive(self, worker: ray.actor.ActorHandle):
def _is_worker_alive(self, worker: ActorHandle) -> bool:
worker_state_dict = get_actor(worker._actor_id.hex())
return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False
def _init_with_detached_workers(self, worker_names):
def _init_with_detached_workers(self, worker_names: List[str]) -> None:
workers = [ray.get_actor(name=name) for name in worker_names]
self._workers = workers
self._world_size = len(worker_names)
def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached):
def _init_with_resource_pool(
self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool
):
use_gpu = resource_pool.use_gpu
strategy = "PACK"
if bin_pack:
strategy = "STRICT_PACK"
pgs = resource_pool.get_placement_groups(strategy=strategy)
world_size = resource_pool.world_size
self._world_size = world_size
......@@ -232,8 +262,8 @@ class RayWorkerGroup(WorkerGroup):
num_gpus = 1 / resource_pool.max_collocate_count
rank = -1
for pg_idx, local_world_size in enumerate(resource_pool.store):
pg = pgs[pg_idx]
local_world_size = resource_pool.store[0]
for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
assert local_world_size <= pg.bundle_count, f"when generating workers for {self.name_prefix}, local_world_size must not exceed the placement group bundle count"
for local_rank in range(local_world_size):
rank += 1
......@@ -251,8 +281,6 @@ class RayWorkerGroup(WorkerGroup):
env_vars["MASTER_ADDR"] = self._master_addr
env_vars["MASTER_PORT"] = self._master_port
import re
cia_name = type(ray_cls_with_init.cls).__name__
match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)"
cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj"
......@@ -272,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
if rank == 0:
register_center_actor = None
for _ in range(120):
for _ in range(360):
if f"{self.name_prefix}_register_center" not in list_named_actors():
time.sleep(1)
else:
......
......@@ -19,7 +19,7 @@ import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Optional, Tuple
from verl.workers.config import WorkerConfig
from ..workers.config import WorkerConfig
def recursive_post_init(dataclass_obj):
......@@ -36,12 +36,13 @@ class DataConfig:
train_files: str = ""
val_files: str = ""
prompt_key: str = "prompt"
answer_key: str = "answer"
image_key: str = "images"
max_prompt_length: int = 512
max_response_length: int = 512
rollout_batch_size: int = 512
return_raw_input_ids: bool = False
return_raw_prompt: bool = False
system_prompt: str = r"Please reason step by step, and put your final answer within \boxed{}."
val_batch_size: int = -1
system_prompt: Optional[str] = None
shuffle: bool = True
seed: int = 1
max_pixels: int = 4194304
......@@ -52,10 +53,12 @@ class DataConfig:
class AlgorithmConfig:
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
adv_estimator: str = "grpo"
disable_kl: bool = False
use_kl_loss: bool = False
kl_penalty: str = "kl"
kl_type: str = "fixed"
kl_coef: float = 1e-3
kl_type: str = "fixed"
kl_horizon: float = 0.0
kl_target: float = 0.0
......@@ -67,18 +70,17 @@ class TrainerConfig:
project_name: str = "easy_r1"
experiment_name: str = "demo"
logger: Tuple[str] = ("console", "wandb")
val_generations_to_log_to_wandb: int = 0
nnodes: int = 1
n_gpus_per_node: int = 8
save_freq: int = -1
load_checkpoint_path: Optional[str] = None
critic_warmup: int = 0
val_freq: int = -1
val_before_train: bool = True
val_only: bool = False
test_freq: int = -1
critic_warmup: int = 0
remove_previous_ckpt: bool = False
del_local_ckpt_after_load: bool = False
val_generations_to_log: int = 0
save_freq: int = -1
save_limit: int = -1
save_checkpoint_path: Optional[str] = None
load_checkpoint_path: Optional[str] = None
def post_init(self):
if self.save_checkpoint_path is None:
......@@ -95,6 +97,10 @@ class PPOConfig:
def post_init(self):
self.worker.rollout.prompt_length = self.data.max_prompt_length
self.worker.rollout.response_length = self.data.max_response_length
self.worker.actor.disable_kl = self.algorithm.disable_kl
self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
self.worker.actor.kl_penalty = self.algorithm.kl_penalty
self.worker.actor.kl_coef = self.algorithm.kl_coef
def deep_post_init(self):
recursive_post_init(self)
......
# Copyright 2022 The HuggingFace Team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,50 +18,55 @@ The function implemented in this file should be used by trainer with different d
implement PPO
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import verl.utils.torch_functional as verl_F
from ..utils import torch_functional as VF
if TYPE_CHECKING:
from verl.trainer.config import AlgorithmConfig
from .config import AlgorithmConfig
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
class KLController(ABC):
@abstractmethod
def update(self, current_kl: float, n_steps: int) -> None: ...
class AdaptiveKLController(KLController):
"""Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf"""
def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
def update(self, current_kl: float, n_steps: int) -> None:
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
class FixedKLController(KLController):
"""Fixed KL controller."""
def __init__(self, kl_coef: float):
self.value = kl_coef
def __init__(self, init_kl_coef: float):
self.value = init_kl_coef
def update(self, current_kl, n_steps):
def update(self, current_kl: float, n_steps: int) -> None:
pass
def get_kl_controller(algorithm_config: "AlgorithmConfig"):
def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
"""Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
if algorithm_config.kl_type == "fixed":
kl_ctrl = FixedKLController(kl_coef=algorithm_config.kl_coef)
kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
elif algorithm_config.kl_type == "adaptive":
assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
kl_ctrl = AdaptiveKLController(
......@@ -70,19 +75,20 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig"):
horizon=algorithm_config.kl_horizon,
)
else:
raise ValueError("Unknown kl_ctrl type")
raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")
return kl_ctrl
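# Hedged usage sketch (illustrative numbers only): a fixed controller keeps kl_coef constant,
# while the adaptive one nudges it toward target_kl over `horizon` steps.
def _example_kl_controllers() -> None:
    fixed = FixedKLController(init_kl_coef=1e-3)
    fixed.update(current_kl=0.5, n_steps=256)  # no-op: value stays 1e-3
    adaptive = AdaptiveKLController(init_kl_coef=1e-3, target_kl=0.1, horizon=10000)
    adaptive.update(current_kl=0.2, n_steps=256)  # current_kl > target_kl, so value increases slightly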
@torch.no_grad()
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513
Args:
token_level_rewards: `(torch.Tensor)`
......@@ -103,27 +109,26 @@ def compute_gae_advantage_return(
shape: (bs, response_length)
"""
with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = (advantages + values) * eos_mask
advantages = VF.masked_whiten(advantages, eos_mask) * eos_mask
return advantages, returns
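# Worked example for GAE (illustrative, gamma = lam = 1.0, eos_mask all ones): with values = [0.5, 0.4]
# and a single terminal reward of 1.0, the backward recursion gives
#   delta_1 = 1.0 + 0 - 0.4 = 0.6          -> A_1 = 0.6
#   delta_0 = 0.0 + 0.4 - 0.5 = -0.1       -> A_0 = -0.1 + 0.6 = 0.5
# so advantages = [0.5, 0.6] before whitening and returns = advantages + values = [1.0, 1.0].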
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
):
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
......@@ -133,6 +138,50 @@ def compute_grpo_outcome_advantage(
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean, id2std = {}, {}
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
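# Worked example for GRPO (illustrative): two responses to the same prompt with outcome rewards
# [1.0, 0.0] give mean = 0.5 and (Bessel-corrected) std of about 0.7071, so the per-token advantages
# are about +0.7071 and -0.7071, broadcast over response_length and masked by eos_mask.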
@torch.no_grad()
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
......@@ -144,31 +193,33 @@ def compute_grpo_outcome_advantage(
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}.")
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
response_num - 1
)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
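# Worked example for RLOO (illustrative): rewards [1.0, 0.0, 1.0] for the same prompt use a
# leave-one-out baseline per sample, giving advantages
#   1.0 - mean([0.0, 1.0]) = +0.5,   0.0 - mean([1.0, 1.0]) = -1.0,   1.0 - mean([0.0, 1.0]) = +0.5
# which matches the r_i * n / (n - 1) - mean * n / (n - 1) form above with n = 3.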
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262
......@@ -184,26 +235,24 @@ def compute_reinforce_plus_plus_outcome_advantage(
Returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
advantages = VF.masked_whiten(returns, eos_mask)
advantages *= eos_mask
returns *= eos_mask
return advantages, returns
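# Illustrative sketch (not part of the diff): the backward loop above accumulates a
# discounted return at every token; the full function additionally zeroes the running
# return past EOS and whitens the result with the repo's masked_whiten helper.
import torch

rewards = torch.tensor([[0.0, 0.0, 1.0]])   # outcome reward placed on the last token
gamma = 0.9
returns = torch.zeros_like(rewards)
running = torch.zeros(rewards.shape[0])
for t in reversed(range(rewards.shape[1])):
    running = rewards[:, t] + gamma * running
    returns[:, t] = running
print(returns)                              # tensor([[0.8100, 0.9000, 1.0000]])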
@torch.no_grad()
def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for ReMax, operating only on the outcome reward.
This implementation is based on the paper: https://arxiv.org/abs/2310.10505
......@@ -225,23 +274,31 @@ def compute_remax_outcome_advantage(
"""
response_length = token_level_rewards.shape[-1]
# scores = token_level_rewards.sum(dim=-1)
with torch.no_grad():
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * eos_mask
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
return advantages, returns
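# Illustrative sketch (not part of the diff): the flip/cumsum/flip above is a reverse
# cumulative sum, i.e. the undiscounted reward-to-go at each token; the greedy-rollout
# baseline is then subtracted from it.
import torch

rewards = torch.tensor([[0.0, 1.0, 2.0]])
reward_to_go = rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
print(reward_to_go)                         # tensor([[3., 3., 2.]])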
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
kl = old_log_prob - ref_log_prob
def compute_rewards(
token_level_scores: torch.Tensor,
log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
kl_ratio: float,
) -> torch.Tensor:
kl = log_probs - ref_log_probs
return token_level_scores - kl * kl_ratio
def compute_policy_loss(
old_log_prob, log_prob, advantages, eos_mask, cliprange
old_log_probs: torch.Tensor,
log_probs: torch.Tensor,
advantages: torch.Tensor,
eos_mask: torch.Tensor,
cliprange: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
"""Compute the policy loss.
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
Args:
old_log_prob: `(torch.Tensor)`
......@@ -260,95 +317,88 @@ def compute_policy_loss(
policy gradient loss computed via PPO
pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped
"""
negative_approx_kl = log_prob - old_log_prob
negative_approx_kl = log_probs - old_log_probs
# clamp the ratio before exp to avoid nan
# see: https://github.com/pytorch/pytorch/issues/10729
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
clipped_ratio = torch.exp(torch.clamp(negative_approx_kl, np.log(1.0 - cliprange), np.log(1.0 + cliprange)))
ppo_kl = VF.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_losses2 = -advantages * clipped_ratio
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
pg_loss = VF.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = VF.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl
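# Illustrative sketch (not part of the diff): the clipped PPO objective above on a single
# token with a positive advantage; the clipped term wins the max, capping the update.
# The new code clamps the log-ratio before exp for numerical safety, which gives the
# same loss value as clamping the ratio directly, as done here.
import torch

old_log_prob = torch.tensor([-1.0])
log_prob = torch.tensor([-0.5])                     # ratio = exp(0.5) ~ 1.65
advantage = torch.tensor([1.0])
cliprange = 0.2

ratio = torch.exp(log_prob - old_log_prob)
unclipped = -advantage * ratio                                               # ~ -1.65
clipped = -advantage * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)  # -1.2
print(torch.max(unclipped, clipped))                                         # tensor([-1.2000])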
def compute_entropy_loss(logits, eos_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss
def compute_value_loss(
vpreds: torch.Tensor,
returns: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
cliprange_value: float,
) -> Tuple[torch.Tensor, float]:
"""Compute the value loss.
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Copied from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556
Args:
vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange_value: (float)
The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347
Returns:
vf_loss: a scalar (`torch.FloatTensor`):
value function loss
vf_clipfrac: a float
The ratio of vf being clipped
"""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = torch.square(vpreds - returns)
vf_losses2 = torch.square(vpredclipped - returns)
vf_loss = 0.5 * VF.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = VF.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac
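# Illustrative sketch (not part of the diff): the clipped value loss above limits how far
# the new value prediction may move from the old value before the squared error is taken.
import torch

values = torch.tensor([1.0])        # old value-head output
vpreds = torch.tensor([2.0])        # new prediction, moved by 1.0
returns = torch.tensor([1.5])
cliprange_value = 0.2

vpred_clipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)  # 1.2
vf_loss = 0.5 * torch.max(torch.square(vpreds - returns), torch.square(vpred_clipped - returns))
print(vf_loss)                      # 0.5 * max(0.25, 0.09) = tensor([0.1250])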
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.Tensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
def kl_penalty(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
"""Compute KL divergence given log_probs and ref_log_probs.
Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Args:
logprob:
ref_logprob:
log_probs: torch.Tensor
ref_log_probs: torch.Tensor
Returns:
kl_div: torch.Tensor
"""
log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
if kl_penalty == "kl":
return logprob - ref_logprob
return log_probs - ref_log_probs
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
return (log_probs - ref_log_probs).abs()
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
return 0.5 * (log_probs - ref_log_probs).square()
# J. Schulman. Approximating kl divergence, 2020.
# # URL http://joschu.net/blog/kl-approx.html.
# URL http://joschu.net/blog/kl-approx.html
if kl_penalty == "low_var_kl":
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
kl = ref_log_probs - log_probs
kld = (kl.exp() - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)
raise NotImplementedError
raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
......@@ -16,81 +16,100 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by o
"""
import json
import torch
import ray
from omegaconf import OmegaConf
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.config import PPOConfig
from verl.trainer.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
from verl.utils import get_processor, get_tokenizer
from verl.workers.fsdp_workers import FSDPWorker
from verl.workers.reward import CustomRewardManager
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import CustomRewardManager
from .config import PPOConfig
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
# please make sure main_task is not scheduled on head
@ray.remote(num_cpus=1)
class Runner:
"""A runner for RL training."""
def run(self, config: PPOConfig):
# print config
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer
tokenizer = get_tokenizer(
config.worker.actor.model.model_path,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
processor = get_processor(
config.worker.actor.model.model_path,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
# define worker classes
ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.ActorRollout: ray.remote(FSDPWorker),
Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
def main():
cli_args = OmegaConf.from_cli()
file_config = OmegaConf.load(cli_args.config)
del cli_args.config
default_config = OmegaConf.structured(PPOConfig())
ppo_config = OmegaConf.merge(default_config, file_config, cli_args)
if hasattr(cli_args, "config"):
config_path = cli_args.pop("config", None)
file_config = OmegaConf.load(config_path)
default_config = OmegaConf.merge(default_config, file_config)
ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
# this is for local ray cluster
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(ppo_config))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config: PPOConfig):
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer
tokenizer = get_tokenizer(config.worker.actor.model.model_path)
processor = get_processor(config.worker.actor.model.model_path, use_fast=True)
# define worker classes
ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.ActorRollout: ray.remote(FSDPWorker),
Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
val_reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
# for rocm
if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ignore_reinit_error=True,
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
else:
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
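# Illustrative sketch (not part of the diff): the precedence of the config merge in main().
# A structured default is merged with the optional YAML file passed as `config=...`, then
# with the remaining dotted CLI overrides; later sources win. The keys below are made up,
# and OmegaConf.from_dotlist stands in for OmegaConf.from_cli().
from omegaconf import OmegaConf

default_config = OmegaConf.create({"trainer": {"nnodes": 1, "n_gpus_per_node": 8}})
cli_args = OmegaConf.from_dotlist(["trainer.nnodes=2"])
merged = OmegaConf.merge(default_config, cli_args)
print(merged.trainer.nnodes)   # 2 -- the CLI override wins over the default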
if __name__ == "__main__":
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
import numpy as np
import torch
from ..protocol import DataProto
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
return {key: np.mean(value) for key, value in metrics.items()}
def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str, Any]:
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].size(-1)
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float()
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
num_response_tokens = torch.sum(batch.batch["response_mask"]).item()
num_overall_tokens = sum(batch.meta_info["global_token_num"])
num_tokens_of_section = {
**dict.fromkeys(["gen", "reward"], num_response_tokens),
**dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens),
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
total_num_tokens = sum(batch.meta_info["global_token_num"])
time = timing_raw["step"]
return {
"perf/total_num_tokens": total_num_tokens,
"perf/time_per_step": time,
"perf/throughput": total_num_tokens / (time * n_gpus),
}
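# Illustrative sketch (not part of the diff): "perf/throughput" above is tokens processed
# per second per GPU in one training step. The numbers below are hypothetical.
total_num_tokens = 1_048_576   # tokens in the global batch
step_time_s = 32.0             # wall-clock seconds for the step
n_gpus = 8
print(total_num_tokens / (step_time_s * n_gpus))   # 4096.0 tokens / s / GPU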
......@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import get_processor, get_tokenizer
__all__ = ["get_processor", "get_tokenizer"]
......@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
__all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
......@@ -11,20 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.distributed
import torch.distributed as dist
from filelock import FileLock
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedTokenizer, ProcessorMixin
class BaseCheckpointManager:
CHECKPOINT_TRACKER = "latest_global_step.txt"
class BaseCheckpointManager(ABC):
"""
A checkpoint manager that saves and loads
- model
......@@ -44,42 +51,27 @@ class BaseCheckpointManager:
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin
processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
):
self.previous_global_step = None
self.previous_save_local_path = None
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.processor = processor
self.processing_class = processing_class
assert isinstance(self.model, FSDP)
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
@abstractmethod
def load_checkpoint(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod
def save_checkpoint(self, *args, **kwargs):
raise NotImplementedError
def remove_previous_save_local_path(self):
if not self.previous_save_local_path:
return
abs_path = os.path.abspath(self.previous_save_local_path)
print(f"Checkpoint manager remove previous save local path: {abs_path}")
if not os.path.exists(abs_path):
return
# remove previous local_path
shutil.rmtree(abs_path, ignore_errors=True)
@staticmethod
def local_mkdir(path):
def local_mkdir(path: str) -> str:
if not os.path.isabs(path):
working_dir = os.getcwd()
path = os.path.join(working_dir, path)
......@@ -89,18 +81,16 @@ class BaseCheckpointManager:
lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
try:
with FileLock(lock_path, timeout=60): # Add timeout
# make a new dir
with FileLock(lock_path, timeout=60):
os.makedirs(path, exist_ok=True)
except Exception as e:
print(f"Warning: Failed to acquire lock for {path}: {e}")
# Even if the lock is not acquired, try to create the directory
os.makedirs(path, exist_ok=True)
os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory
return path
@staticmethod
def get_rng_state():
def get_rng_state() -> Dict[str, Any]:
rng_state = {
"cpu": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state(),
......@@ -110,14 +100,14 @@ class BaseCheckpointManager:
return rng_state
@staticmethod
def load_rng_state(rng_state):
def load_rng_state(rng_state: Dict[str, Any]):
torch.set_rng_state(rng_state["cpu"])
torch.cuda.set_rng_state(rng_state["cuda"])
np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"])
def find_latest_ckpt_path(path, directory_format="global_step_{}"):
def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
if path is None:
return None
......@@ -128,6 +118,7 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
with open(tracker_file, "rb") as f:
iteration = int(f.read().decode())
ckpt_path = os.path.join(path, directory_format.format(iteration))
if not os.path.exists(ckpt_path):
print("Checkpoint does not exist: %s", ckpt_path)
......@@ -137,8 +128,33 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
return ckpt_path
def get_checkpoint_tracker_filename(root_path: str):
def get_checkpoint_tracker_filename(root_path: str) -> str:
"""
Tracker file records the latest checkpoint during training to restart from.
"""
return os.path.join(root_path, "latest_checkpointed_iteration.txt")
return os.path.join(root_path, CHECKPOINT_TRACKER)
def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
"""
Remove the obsolete checkpoints that exceed the save_limit.
"""
if save_limit <= 0:
return
if not os.path.exists(path):
return
pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
ckpt_folders = []
for folder in os.listdir(path):
if match := re.match(pattern, folder):
step = int(match.group(1))
if step < global_step:
ckpt_folders.append((step, folder))
ckpt_folders.sort(reverse=True)
for _, folder in ckpt_folders[save_limit - 1 :]:
folder_path = os.path.join(path, folder)
shutil.rmtree(folder_path, ignore_errors=True)
print(f"Removed obsolete checkpoint: {folder_path}")