Commit ff7fb65e authored by chenych

Update

parent c132cbcb
......@@ -66,7 +66,7 @@ if __name__ == "__main__":
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}"
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}"
if "tp" in mesh_dim_names:
# fsdp * tp
......@@ -106,9 +106,10 @@ if __name__ == "__main__":
if isinstance(tensor, DTensor):
state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
# replicated placement at dp dimension can be discarded
if mesh_dim_names[0] == "dp":
# replicated placement at ddp dimension can be discarded
if mesh_dim_names[0] == "ddp":
placements = placements[1:]
if key not in param_placements:
param_placements[key] = placements
else:
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Supported models using HF Rmpad
# TODO(sgm): HF may support more models than listed here; we should add more after testing
from transformers import GemmaConfig, LlamaConfig, MistralConfig, Qwen2Config
_REOVEPAD_MODELS = {"llama": LlamaConfig, "mistral": MistralConfig, "gemma": GemmaConfig, "qwen2": Qwen2Config}
def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
if model_type not in _REOVEPAD_MODELS.keys():
raise ValueError(
f"Model architecture {model_type} is not supported for now. "
f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}."
f"Please set `use_remove_padding=False` in the model config."
)
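# Hedged usage sketch (not part of this commit): gate the remove-padding code path on the
# model architecture before enabling it. `maybe_enable_rmpad` and its arguments are
# hypothetical; only `check_model_support_rmpad` comes from this file.
from transformers import AutoConfig

def maybe_enable_rmpad(model_name_or_path: str, use_remove_padding: bool) -> bool:
    if not use_remove_padding:
        return False
    hf_config = AutoConfig.from_pretrained(model_name_or_path)
    check_model_support_rmpad(hf_config.model_type)  # raises ValueError if unsupported
    return True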
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import logging
from verl.utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_world_size,
)
logger = logging.get_logger(__name__)
def llama_flash_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Adapted from transformers 4.47.1.
"""
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# trade off: repeat first and then all to all
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2) # full seq length
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
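# Single-process illustration (assumption: this mock is NOT verl's gather_seq_scatter_heads,
# which performs a real all_to_all over the Ulysses sequence-parallel group). It only shows
# the layout change described in the comments above:
# (bsz, n_head, seq_len/sp, head_dim) -> (bsz, n_head/sp, seq_len, head_dim).
import torch

def mock_gather_seq_scatter_heads(local_q: torch.Tensor, sp_size: int) -> torch.Tensor:
    bsz, n_head, seq_shard, head_dim = local_q.shape
    assert n_head % sp_size == 0, "head count must be divisible by the SP world size"
    # Pretend every rank holds an identical shard: stitch sp_size shards along the sequence
    # dim, then keep only this rank's slice of the heads.
    full_seq = torch.cat([local_q] * sp_size, dim=2)   # (bsz, n_head, seq_len, head_dim)
    return full_seq[:, : n_head // sp_size]            # (bsz, n_head/sp, seq_len, head_dim)

q_shard = torch.randn(2, 8, 16, 64)                    # seq_len 64 split across sp_size=4 ranks
print(mock_gather_seq_scatter_heads(q_shard, sp_size=4).shape)  # torch.Size([2, 2, 64, 64])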
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch function to models
"""
#### Open Source Models
#### transformers version < 4.48
import importlib.metadata
from functools import lru_cache
from packaging import version
from transformers import PretrainedConfig
def apply_monkey_patch_to_llama():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from verl.models.transformers.llama import llama_flash_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward
def apply_monkey_patch_to_qwen2():
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
from verl.models.transformers.qwen2 import qwen2_flash_attn_forward
Qwen2FlashAttention2.forward = qwen2_flash_attn_forward
_PATCH_NAME_TO_FUNC = {
"llama": apply_monkey_patch_to_llama,
"qwen2": apply_monkey_patch_to_qwen2,
}
def apply_monkey_patch(config: PretrainedConfig, verbose=True):
if not is_transformers_version_in_range("4.45.0", "4.47.1"):
raise AssertionError(
"The installed `transformers` version doesn't support ulysses patch. "
"Please install a version between 4.45.0 and 4.47.1 to use this ulysses feature."
)
success_apply_monkey_patch = False
if config.model_type in _PATCH_NAME_TO_FUNC:
_PATCH_NAME_TO_FUNC[config.model_type]()
success_apply_monkey_patch = True
if success_apply_monkey_patch and verbose:
print(f"Applying monkey patch to model {config.model_type}")
elif not success_apply_monkey_patch:
raise NotImplementedError(
f"Ulysses for model {config.model_type} is not implemented, "
"please set `ulysses_sequence_parallel_size=1`."
)
return success_apply_monkey_patch
@lru_cache
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
try:
# Get the installed version of the transformers library
transformers_version = importlib.metadata.version("transformers")
except importlib.metadata.PackageNotFoundError:
raise ModuleNotFoundError("The `transformers` package is not installed.")
# Check if the version is within the specified range
return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)
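# Hedged usage sketch (not part of this commit): apply the Ulysses patch only after the
# version guard passes. `Qwen2Config()` is just a convenient way to build a config whose
# model_type is "qwen2"; any PretrainedConfig with a supported model_type works.
if __name__ == "__main__":
    from transformers import Qwen2Config

    if is_transformers_version_in_range("4.45.0", "4.47.1"):
        apply_monkey_patch(Qwen2Config(), verbose=True)
    else:
        print("transformers version outside [4.45.0, 4.47.1]; skipping Ulysses patch.")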
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging
from verl.utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
get_ulysses_sequence_parallel_world_size,
)
logger = logging.get_logger(__name__)
def qwen2_flash_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
########## AlltoAll for Ulysses ##########
ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
if ulysses_sp_size > 1:
# (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
full_q_len = query_states.size(2) # full seq length
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
else:
sliding_window = None
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
full_q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
# use full_q_len to reshape
attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
from typing import Optional
import torch
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor
def get_rope_index(
processor: Qwen2_5_VLProcessor,
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
video_grid_thw: Optional[torch.Tensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Gets the position ids for Qwen2-VL; they should be generated before sharding the sequence.
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
"""
spatial_merge_size = processor.image_processor.merge_size
tokens_per_second = 2
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
image_index, video_index = 0, 0
input_ids = input_ids[attention_mask == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
else:
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
return position_ids
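# Minimal standalone check of the text-only fallback above (no images or videos): the three
# mRoPE channels collapse to ordinary 1-D positions driven by the attention mask. This
# re-derives the `else` branch without needing a Qwen2_5_VLProcessor instance.
if __name__ == "__main__":
    attention_mask = torch.tensor([0, 0, 1, 1, 1, 1])  # left padding
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    position_ids = position_ids.unsqueeze(0).expand(3, -1)
    print(position_ids)
    # tensor([[1, 1, 0, 1, 2, 3],
    #         [1, 1, 0, 1, 2, 3],
    #         [1, 1, 0, 1, 2, 3]])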
......@@ -122,7 +122,7 @@ def batch_collate(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
return batch_features
def fold_batch_dim(data: "DataProto", new_batch_size):
def fold_batch_dim(data: "DataProto", new_batch_size: int):
"""
Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
"""
......@@ -158,8 +158,8 @@ def collate_fn(data_items: list["DataProtoItem"]):
@dataclass
class DataProtoItem:
batch: Optional[TensorDict] = None
non_tensor_batch: Dict = field(default_factory=dict)
meta_info: Dict = field(default_factory=dict)
non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
meta_info: Dict[str, Any] = field(default_factory=dict)
@dataclass
......@@ -172,13 +172,13 @@ class DataProto:
"""
batch: Optional[TensorDict] = None
non_tensor_batch: Dict[str, Any] = field(default_factory=dict)
non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
meta_info: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
self.check_consistency() # perform necessary checking
def __len__(self):
def __len__(self) -> int:
if self.batch is not None:
return self.batch.batch_size[0]
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
......@@ -187,43 +187,41 @@ class DataProto:
else:
return 0
def __getitem__(self, item):
def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]:
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self):
def __getstate__(self) -> Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]:
buffer = io.BytesIO()
if self.batch is not None:
self.batch = self.batch.contiguous()
self.batch = self.batch.consolidate()
self.batch: TensorDict = self.batch.contiguous()
self.batch: TensorDict = self.batch.consolidate()
torch.save(self.batch, buffer)
buffer_bytes = buffer.getvalue()
return buffer_bytes, self.non_tensor_batch, self.meta_info
def __setstate__(self, data):
def __setstate__(self, data: Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]) -> None:
batch_deserialized_bytes, non_tensor_batch, meta_info = data
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
batch = torch.load(
batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None
)
batch_deserialized = io.BytesIO(batch_deserialized_bytes)
batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu")
self.batch = batch
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info
def save_to_disk(self, filepath):
def save_to_disk(self, filepath: str) -> None:
with open(filepath, "wb") as f:
pickle.dump(self, f)
@staticmethod
def load_from_disk(filepath) -> "DataProto":
def load_from_disk(filepath: str) -> "DataProto":
with open(filepath, "rb") as f:
data = pickle.load(f)
return data
def print_size(self, prefix=""):
def print_size(self, prefix: str = "") -> None:
size_of_tensordict = 0
for tensor in self.batch.values():
if isinstance(tensor, torch.Tensor):
......@@ -255,7 +253,11 @@ class DataProto:
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
@classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, NDArray]], meta_info=None):
def from_single_dict(
cls,
data: Dict[str, Union[torch.Tensor, NDArray]],
meta_info: Optional[Dict[str, Any]] = None,
) -> "DataProto":
tensors = {}
non_tensors = {}
for key, value in data.items():
......@@ -269,7 +271,13 @@ class DataProto:
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
@classmethod
def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
def from_dict(
cls,
tensors: Dict[str, torch.Tensor],
non_tensors: Optional[Dict[str, NDArray]] = None,
meta_info: Optional[Dict[str, Any]] = None,
num_batch_dims: int = 1,
) -> "DataProto":
"""Create a DataProto from a dict of tensors. This assumes that
1. All the tensor in tensors have the same dim0
2. Only dim0 is the batch dim
......@@ -293,13 +301,14 @@ class DataProto:
else:
current_batch = tensor.shape[:num_batch_dims]
assert batch_size == current_batch, (
f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}"
f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. "
f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
)
tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
def to(self, device) -> "DataProto":
def to(self, device: torch.device) -> "DataProto":
"""move the batch to device
Args:
......@@ -314,7 +323,13 @@ class DataProto:
return self
def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto":
def select(
self,
batch_keys: Optional[List[str]] = None,
non_tensor_batch_keys: Optional[List[str]] = None,
meta_info_keys: Optional[List[str]] = None,
deepcopy: bool = False,
) -> "DataProto":
"""Select a subset of the DataProto via batch_keys and meta_info_keys
Args:
......@@ -332,7 +347,7 @@ class DataProto:
sub_batch = self.batch
if non_tensor_batch_keys is not None:
non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
non_tensor_batch = {k: v for k, v in self.non_tensor_batch.items() if k in non_tensor_batch_keys}
else:
non_tensor_batch = self.non_tensor_batch
......@@ -340,7 +355,7 @@ class DataProto:
non_tensor_batch = copy.deepcopy(non_tensor_batch)
if meta_info_keys is not None:
sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
sub_meta_info = {k: v for k, v in self.meta_info.items() if k in meta_info_keys}
else:
sub_meta_info = self.meta_info
......@@ -349,7 +364,12 @@ class DataProto:
return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto":
def pop(
self,
batch_keys: Optional[List[str]] = None,
non_tensor_batch_keys: Optional[List[str]] = None,
meta_info_keys: Optional[List[str]] = None,
) -> "DataProto":
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
Args:
......@@ -377,7 +397,9 @@ class DataProto:
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
def rename(self, old_keys=None, new_keys=None) -> "DataProto":
def rename(
self, old_keys: Optional[Union[str, List[str]]] = None, new_keys: Optional[Union[str, List[str]]] = None
) -> "DataProto":
"""
Note that this function only renames the keys in the batch
"""
......@@ -422,7 +444,9 @@ class DataProto:
self.meta_info = union_two_dict(self.meta_info, other.meta_info)
return self
def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
def make_iterator(
self, mini_batch_size: int, epochs: int, seed: Optional[int] = None, dataloader_kwargs: Optional[Dict[str, Any]] = None
):
"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
......@@ -509,9 +533,7 @@ class DataProto:
Returns:
DataProto: concatenated DataProto
"""
batch_lst = []
for batch in data:
batch_lst.append(batch.batch)
batch_lst = [batch.batch for batch in data]
if batch_lst[0] is not None:
new_batch = torch.cat(batch_lst, dim=0)
else:
......@@ -523,7 +545,7 @@ class DataProto:
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
def reorder(self, indices):
def reorder(self, indices: torch.Tensor) -> None:
"""
Note that this operation is in-place
"""
......@@ -531,7 +553,7 @@ class DataProto:
self.batch = self.batch[indices]
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
def repeat(self, repeat_times=2, interleave=True):
def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto":
"""
Repeat the batch data a specified number of times.
......
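# Hedged usage sketch (not part of the diff): build a DataProto from tensors plus a numpy
# "non tensor" field and slice it. Signatures follow the annotations above; the import path
# and the key names ("input_ids", "uid") are assumptions for illustration only.
import numpy as np
import torch
from verl.protocol import DataProto  # assumed import path

tensors = {"input_ids": torch.randint(0, 100, (4, 8))}
non_tensors = {"uid": np.array(["a", "b", "c", "d"], dtype=object)}
data = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info={"split": "train"})

print(len(data))  # 4
sub = data.select(batch_keys=["input_ids"], non_tensor_batch_keys=["uid"])
first_two = sub[0:2]  # slice -> DataProto, int index -> DataProtoItem
print(first_two.batch["input_ids"].shape, first_two.non_tensor_batch["uid"])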
......@@ -42,7 +42,7 @@ class DataConfig:
max_response_length: int = 512
rollout_batch_size: int = 512
val_batch_size: int = -1
system_prompt: Optional[str] = None
format_prompt: Optional[str] = None
shuffle: bool = True
seed: int = 1
max_pixels: int = 4194304
......
......@@ -34,15 +34,22 @@ if TYPE_CHECKING:
class KLController(ABC):
kl_coef: float
"""KL coefficient."""
@abstractmethod
def update(self, current_kl: float, n_steps: int) -> None: ...
def update(self, current_kl: float, n_steps: int) -> None:
"""Update kl_coef according to current KL."""
...
class AdaptiveKLController(KLController):
"""Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf"""
"""Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf
Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54"""
def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
self.value = init_kl_coef
self.kl_coef = init_kl_coef
self.target = target_kl
self.horizon = horizon
......@@ -50,14 +57,16 @@ class AdaptiveKLController(KLController):
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
self.kl_coef *= mult
class FixedKLController(KLController):
"""Fixed KL controller."""
"""Fixed KL controller.
Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72"""
def __init__(self, init_kl_coef: float):
self.value = init_kl_coef
self.kl_coef = init_kl_coef
def update(self, current_kl: float, n_steps: int) -> None:
pass
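# Numeric sanity check of the adaptive update above (standalone re-implementation, not an
# import of this class): when the measured KL exceeds the target the coefficient grows, and
# when it falls below the target the coefficient shrinks. The numbers are illustrative.
import numpy as np

def adaptive_update(kl_coef: float, current_kl: float, target: float, horizon: float, n_steps: int) -> float:
    proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
    return kl_coef * (1 + proportional_error * n_steps / horizon)

print(adaptive_update(0.05, current_kl=0.12, target=0.06, horizon=10000, n_steps=256))  # > 0.05
print(adaptive_update(0.05, current_kl=0.02, target=0.06, horizon=10000, n_steps=256))  # < 0.05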
......@@ -84,7 +93,7 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
response_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -95,8 +104,8 @@ def compute_gae_advantage_return(
shape: (bs, response_length)
values: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length). [EOS] mask. The token after [EOS] have mask zero.
response_mask: `(torch.Tensor)`
shape: (bs, response_length). Tokens after the eos token have mask zero.
gamma: `(float)`
discount factor used in RL
lam: `(float)`
......@@ -105,7 +114,7 @@ def compute_gae_advantage_return(
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
......@@ -119,32 +128,33 @@ def compute_gae_advantage_return(
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = (advantages + values) * eos_mask
advantages = VF.masked_whiten(advantages, eos_mask) * eos_mask
returns = advantages + values
advantages = VF.masked_whiten(advantages, response_mask)
return advantages, returns
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
@torch.no_grad()
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean, id2std = {}, {}
......@@ -154,86 +164,78 @@ def compute_grpo_outcome_advantage(
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
raise ValueError(f"no score in prompt index: {idx}")
id2std[idx] = torch.std(torch.tensor(id2score[idx]))
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
returns = scores.unsqueeze(-1) * response_mask
return returns, returns
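# Worked example of the group-wise normalization above (standalone, not an import): two
# rollouts share prompt index "p0" and two share "p1"; each reward is standardized against
# the mean/std of its own group before being broadcast over the response tokens.
import torch

scores = torch.tensor([1.0, 0.0, 0.5, 0.7])
index = ["p0", "p0", "p1", "p1"]
eps = 1e-6
normalized = scores.clone()
for group in set(index):
    ids = [i for i, g in enumerate(index) if g == group]
    group_scores = scores[ids]
    normalized[ids] = (group_scores - group_scores.mean()) / (group_scores.std() + eps)
print(normalized)  # p0 -> [+0.7071, -0.7071], p1 -> [-0.7071, +0.7071] (approximately)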
@torch.no_grad()
def compute_rloo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2sum = {}
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}.")
id2sum[idx] = torch.sum(torch.tensor(id2score[idx]))
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (
response_num - 1
)
sample_num = len(id2score[index[i]])
assert sample_num > 1, "RLOO needs rollout.n > 1."
baseline = (id2sum[index[i]] - scores[i]) / (sample_num - 1)
scores[i] = scores[i] - baseline
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
returns = scores.unsqueeze(-1) * response_mask
return returns, returns
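# Worked example of the leave-one-out baseline above (standalone): with rewards
# [1.0, 0.0, 0.5] for one prompt, each sample's baseline is the mean of the OTHER samples,
# i.e. (sum - own_score) / (n - 1).
import torch

scores = torch.tensor([1.0, 0.0, 0.5])
total, n = scores.sum(), scores.numel()
baselines = (total - scores) / (n - 1)  # tensor([0.2500, 0.7500, 0.5000])
advantages = scores - baselines         # tensor([ 0.7500, -0.7500,  0.0000])
print(baselines, advantages)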
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
returns = torch.zeros_like(token_level_rewards)
running_return = 0
......@@ -241,17 +243,15 @@ def compute_reinforce_plus_plus_outcome_advantage(
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
running_return = running_return * response_mask[:, t]
advantages = VF.masked_whiten(returns, eos_mask)
advantages *= eos_mask
returns *= eos_mask
advantages = VF.masked_whiten(returns, response_mask)
return advantages, returns
@torch.no_grad()
def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for ReMax, operating only on Outcome reward
......@@ -263,20 +263,19 @@ def compute_remax_outcome_advantage(
shape: (bs, response_length)
reward_baselines: `(torch.Tensor)`
shape: (bs,)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
Returns: `(torch.Tensor)`
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
# scores = token_level_rewards.sum(dim=-1)
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) * eos_mask
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
return advantages, returns
scores = token_level_rewards.sum(dim=-1) - reward_baselines
returns = scores.unsqueeze(-1) * response_mask
return returns, returns
def compute_rewards(
......@@ -293,9 +292,11 @@ def compute_policy_loss(
old_log_probs: torch.Tensor,
log_probs: torch.Tensor,
advantages: torch.Tensor,
eos_mask: torch.Tensor,
cliprange: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
response_mask: torch.Tensor,
clip_ratio_low: float,
clip_ratio_high: float,
clip_ratio_dual: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute the policy loss.
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568
......@@ -307,42 +308,61 @@ def compute_policy_loss(
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange: (float)
The clip range used in PPO. See https://arxiv.org/abs/1707.06347
clip_ratio_low: (float)
The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
clip_ratio_high: (float)
The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
clip_ratio_dual: (float)
The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
Returns:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via PPO
pg_clipfrac: (float)
a float number indicating the fraction of policy gradient loss being clipped
pg_clipfrac_higher: (float)
a float number indicating the fraction of policy gradient loss being clipped to a higher value
pg_clipfrac_lower: (float)
a float number indicating the fraction of policy gradient loss being clipped to a lower value
ppo_kl: (float)
a float number indicating the mean KL divergence between the old policy and the new policy
"""
negative_approx_kl = log_probs - old_log_probs
# clamp the ratio before exp to avoid nan
# see: https://github.com/pytorch/pytorch/issues/10729
ratio = torch.exp(negative_approx_kl)
clipped_ratio = torch.exp(torch.clamp(negative_approx_kl, np.log(1.0 - cliprange), np.log(1.0 + cliprange)))
ppo_kl = VF.masked_mean(-negative_approx_kl, eos_mask)
clipped_ratio = torch.exp(
torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
)
pg_loss = -advantages * ratio
pg_loss2 = -advantages * clipped_ratio
pg_loss3 = -advantages * clip_ratio_dual
pg_losses = -advantages * ratio
pg_losses2 = -advantages * clipped_ratio
clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2) # clip if pg_loss < pg_loss2
pg_clipfrac_higher = (pg_loss < pg_loss2).float()
clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3) # clip if pg_loss > pg_loss3 and adv < 0
final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher)
pg_clipfrac_lower = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float()
pg_loss = VF.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = VF.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl
final_pg_loss = VF.masked_mean(final_pg_loss, response_mask)
pg_clipfrac_higher = VF.masked_mean(pg_clipfrac_higher, response_mask)
pg_clipfrac_lower = VF.masked_mean(pg_clipfrac_lower, response_mask)
ppo_kl = VF.masked_mean(-negative_approx_kl, response_mask)
return final_pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl
def compute_value_loss(
vpreds: torch.Tensor,
returns: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
action_mask: torch.Tensor,
cliprange_value: float,
) -> Tuple[torch.Tensor, float]:
"""Compute the value loss.
Copied from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556
Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556
Args:
vpreds (`torch.FloatTensor`):
......@@ -351,7 +371,7 @@ def compute_value_loss(
Ground truth returns, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
eos_mask: `(torch.Tensor)`
action_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange_value: (float)
The clip range for value net used in PPO. See https://arxiv.org/abs/1707.06347
......@@ -361,25 +381,29 @@ def compute_value_loss(
value function loss
vf_clipfrac: a float
The ratio of vf being clipped
"""
vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = torch.square(vpreds - returns)
vf_losses2 = torch.square(vpredclipped - returns)
vf_loss = 0.5 * VF.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = VF.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
vf_loss1 = torch.square(vpreds - returns)
vf_loss2 = torch.square(vpredclipped - returns)
vf_loss = 0.5 * VF.masked_mean(torch.max(vf_loss1, vf_loss2), action_mask) # clip if vf_loss1 < vf_loss2
vf_clipfrac = VF.masked_mean((vf_loss1 < vf_loss2).float(), action_mask)
return vf_loss, vf_clipfrac
def kl_penalty(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
"""Compute KL divergence given log_probs and ref_log_probs.
Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150
Args:
log_probs: torch.Tensor
ref_log_probs: torch.Tensor
kl_penalty: str
Returns:
kl_div: torch.Tensor
"""
log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
if kl_penalty == "kl":
......
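# Standalone sketch of the dual-clip PPO objective introduced above (plain torch, no verl
# masking helpers): the usual PPO clip uses [1 - clip_ratio_low, 1 + clip_ratio_high]; for
# negative advantages the loss is additionally capped at -advantages * clip_ratio_dual
# (https://arxiv.org/pdf/1912.09729). The default values below are illustrative only.
import torch

def dual_clip_pg_loss(log_probs, old_log_probs, advantages,
                      clip_ratio_low=0.2, clip_ratio_high=0.28, clip_ratio_dual=3.0):
    ratio = torch.exp(log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high)
    loss_unclipped = -advantages * ratio
    loss_clipped = -advantages * clipped_ratio
    loss = torch.max(loss_unclipped, loss_clipped)      # standard PPO clip
    dual_cap = -advantages * clip_ratio_dual            # extra cap, applied only when adv < 0
    loss = torch.where(advantages < 0, torch.min(loss, dual_cap), loss)
    return loss.mean()

log_probs = torch.tensor([-0.5, -2.5])
old_log_probs = torch.tensor([-1.0, -1.0])
advantages = torch.tensor([1.0, -1.0])
print(dual_clip_pg_loss(log_probs, old_log_probs, advantages))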
......@@ -69,8 +69,8 @@ class Runner:
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, compute_score=config.worker.reward.compute_score)
reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
trainer = RayPPOTrainer(
config=config,
......@@ -98,9 +98,8 @@ def main():
ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
# this is for local ray cluster
if not ray.is_initialized():
# for rocm
# this is for local ray cluster
if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ignore_reinit_error=True,
......@@ -108,6 +107,7 @@ def main():
else:
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
......
......@@ -23,7 +23,7 @@ from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import numpy as np
import ray
......@@ -50,9 +50,6 @@ from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
WorkerType = Type[Worker]
class Role(IntEnum):
"""
To create more roles dynamically, you can subclass Role and add new members
......@@ -132,25 +129,19 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penal
# compute kl between ref_policy and current policy
if "ref_log_probs" in data.batch.keys():
kld = core_algos.kl_penalty(
data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty
) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
kld = kld * response_mask # (batch_size, response_length)
else:
beta = 0
kld = torch.zeros_like(response_mask, dtype=torch.float32)
token_level_rewards = token_level_scores - beta * kld
data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef}
# According to https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L880
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards
metrics = {"critic/kl": current_kl, "critic/kl_coef": beta}
return data, metrics
......@@ -161,25 +152,21 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma:
if adv_estimator == AdvantageEstimator.GAE:
values = data.batch["values"]
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=token_level_rewards, values=values, eos_mask=response_mask, gamma=gamma, lam=lam
token_level_rewards, values, response_mask, gamma, lam
)
elif adv_estimator == AdvantageEstimator.GRPO:
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
)
advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma
token_level_rewards, response_mask, gamma
)
elif adv_estimator == AdvantageEstimator.REMAX:
reward_baselines = data.batch["reward_baselines"]
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=token_level_rewards, reward_baselines=reward_baselines, eos_mask=response_mask
token_level_rewards, reward_baselines, response_mask
)
elif adv_estimator == AdvantageEstimator.RLOO:
advantages, returns = core_algos.compute_rloo_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
)
advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index)
else:
raise NotImplementedError
......@@ -206,11 +193,11 @@ class RayPPOTrainer:
config: PPOConfig,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
role_worker_mapping: dict[Role, WorkerType],
role_worker_mapping: dict[Role, Type[Worker]],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn: Callable = None,
val_reward_fn: Callable = None,
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
):
self.tokenizer = tokenizer
self.processor = processor
......@@ -249,10 +236,31 @@ class RayPPOTrainer:
raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.")
if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")
raise ValueError("Rollout batch size must be divisible by actor global batch size.")
if (
config.data.rollout_batch_size * config.worker.rollout.n
) % config.worker.actor.micro_batch_size_per_device_for_experience != 0:
raise ValueError(
"Rollout batch size * rollout.n must be divisible by actor micro batch size for experience."
)
if self.use_critic and config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by global batch size.")
if self.use_critic:
if config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
raise ValueError("Rollout batch size must be divisible by critic global batch size.")
if (
config.data.rollout_batch_size * config.worker.rollout.n
) % config.worker.critic.micro_batch_size_per_device_for_experience != 0:
raise ValueError(
"Rollout batch size * rollout.n must be divisible by critic micro batch size for experience."
)
if (
config.algorithm.adv_estimator in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO)
and config.worker.rollout.n == 1
):
raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.")
self._create_dataloader()
......@@ -266,7 +274,7 @@ class RayPPOTrainer:
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
system_prompt=self.config.data.system_prompt,
format_prompt=self.config.data.format_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
......@@ -297,7 +305,7 @@ class RayPPOTrainer:
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
system_prompt=self.config.data.system_prompt,
format_prompt=self.config.data.format_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
......@@ -328,13 +336,15 @@ class RayPPOTrainer:
self.config.worker.critic.optim.training_steps = training_steps
print(f"Total training steps: {self.training_steps}")
def _maybe_log_val_generations(self, inputs: List[str], outputs: List[str], scores: List[float]) -> None:
def _maybe_log_val_generations(
self, inputs: List[str], outputs: List[str], labels: List[str], scores: List[float]
) -> None:
"""Log a table of validation samples"""
if self.config.trainer.val_generations_to_log <= 0:
return
# Create tuples of (input, output, label, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples = list(zip(inputs, outputs, labels, scores))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling
......@@ -347,10 +357,10 @@ class RayPPOTrainer:
def _validate(self) -> Dict[str, Any]:
reward_tensor_lst = []
# Lists to collect samples for the table
sample_inputs, sample_outputs, sample_scores = [], [], []
sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
reward_metrics_lst = defaultdict(list)
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
for batch_dict in self.val_dataloader:
test_batch = DataProto.from_single_dict(batch_dict)
# Store original inputs
input_ids = test_batch.batch["input_ids"]
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
......@@ -377,7 +387,7 @@ class RayPPOTrainer:
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
sample_labels.extend(test_batch.non_tensor_batch["ground_truth"].tolist())
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
......@@ -391,7 +401,7 @@ class RayPPOTrainer:
for key, value in reward_metrics.items():
reward_metrics_lst[key].extend(value)
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
return {"val/reward_score": reward_score, **val_reward_metrics}
......@@ -576,7 +586,8 @@ class RayPPOTrainer:
if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["temperature"] = 0.0
gen_baseline_batch.meta_info["temperature"] = 0
gen_baseline_batch.meta_info["n"] = 1
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
......@@ -634,7 +645,8 @@ class RayPPOTrainer:
with _timer("adv", timing_raw):
# apply kl penalty if available
if not self.config.algorithm.use_kl_loss and self.use_reference_policy: # apply kl penalty to reward
if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
# apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
......
......@@ -90,7 +90,7 @@ class RLHFDataset(Dataset, ImageProcessMixin):
image_key: str = "images",
max_prompt_length: int = 1024,
truncation: str = "error",
system_prompt: str = None,
format_prompt: str = None,
max_pixels: int = None,
min_pixels: int = None,
):
......@@ -101,7 +101,7 @@ class RLHFDataset(Dataset, ImageProcessMixin):
self.image_key = image_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.system_prompt = system_prompt
self.format_prompt = format_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
......@@ -123,8 +123,8 @@ class RLHFDataset(Dataset, ImageProcessMixin):
def __getitem__(self, index):
row_dict: dict = self.dataset[index]
prompt_str: str = row_dict[self.prompt_key]
if self.system_prompt:
prompt_str = " ".join((self.system_prompt.strip(), prompt_str))
if self.format_prompt:
prompt_str = prompt_str + " " + self.format_prompt.strip()
if self.image_key in row_dict:
# https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
......
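# Tiny illustration of the dataset change above: the format prompt is now appended after the
# question instead of a system prompt being prepended. The instruction text here is a made-up
# example, not the project's default format_prompt.
question = "What is 15 * 7?"
format_prompt = "Please reason step by step, and put the final answer in \\boxed{}."
prompt_str = question + " " + format_prompt.strip()
print(prompt_str)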
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A Ray logger will receive logging info from different processes.
"""
import numbers
from typing import Dict
def concat_dict_to_str(data: Dict, step: int) -> str:
output = [f"step {step}:"]
for k, v in data.items():
if isinstance(v, numbers.Number):
output.append(f"{k}:{v:.3f}")
output_str = " - ".join(output)
return output_str
class LocalLogger:
def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
self.print_to_console = print_to_console
if print_to_console:
print("Using LocalLogger is deprecated. The constructor API will change.")
def flush(self):
pass
def log(self, data, step):
if self.print_to_console:
print(concat_dict_to_str(data, step=step), flush=True)
......@@ -31,22 +31,23 @@ if is_package_available("swanlab"):
@dataclass
class GenerationLogger(ABC):
@abstractmethod
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...
def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: ...
@dataclass
class ConsoleGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for inp, out, score in samples:
print(f"[prompt] {inp}\n[output] {out}\n[score] {score}\n")
def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
for inp, out, lab, score in samples:
print(f"[prompt] {inp}\n[output] {out}\n[ground_truth] {lab}\n[score] {score}\n")
@dataclass
class WandbGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
[[f"input_{i + 1}", f"output_{i + 1}", f"label_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))],
[],
)
if not hasattr(self, "validation_table"):
......@@ -69,10 +70,12 @@ class WandbGenerationLogger(GenerationLogger):
@dataclass
class SwanlabGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
swanlab_text_list = []
for i, sample in enumerate(samples):
row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
row_text = "\n\n---\n\n".join(
(f"input: {sample[0]}", f"output: {sample[1]}", f"label: {sample[2]}", f"score: {sample[3]}")
)
swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
swanlab.log({"val/generations": swanlab_text_list}, step=step)
......@@ -94,6 +97,6 @@ class AggregateGenerationsLogger:
if logger in GEN_LOGGERS:
self.loggers.append(GEN_LOGGERS[logger]())
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
for logger in self.loggers:
logger.log(samples, step)
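# Illustrative only (not part of this diff): with this change, callers pass
# 4-tuples of (prompt, output, ground_truth, score) instead of 3-tuples.
# The sample values below are invented.
if __name__ == "__main__":
    demo_samples = [
        ("What is 2 + 2?", "<think>2 + 2 = 4</think>4", "4", 1.0),
        ("What is 3 * 3?", "<think>3 * 3 = 6</think>6", "9", 0.0),
    ]
    ConsoleGenerationLogger().log(demo_samples, step=1)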
......@@ -71,7 +71,7 @@ class TensorBoardLogger(Logger):
os.makedirs(tensorboard_dir, exist_ok=True)
print(f"Saving tensorboard log to {tensorboard_dir}.")
self.writer = SummaryWriter(tensorboard_dir)
self.writer.add_hparams(flatten_dict(config))
self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={})
def log(self, data: Dict[str, Any], step: int) -> None:
for key, value in data.items():
......@@ -146,7 +146,7 @@ class Tracker:
for logger in self.loggers:
logger.log(data=data, step=step)
def log_generation(self, samples: List[Tuple[str, str, float]], step: int) -> None:
def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None:
self.gen_logger.log(samples, step)
def __del__(self):
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
def log_gpu_memory_usage(head: str, rank: int = 0):
if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
memory_allocated = torch.cuda.memory_allocated() / 1024**3
memory_reserved = torch.cuda.memory_reserved() / 1024**3
print(f"{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}.")
......@@ -30,6 +30,7 @@ def math_acc_reward(predict_str: str, ground_truth: str) -> float:
def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str) # handle qwen2.5vl-32b format
format = math_format_reward(predict_str)
accuracy = math_acc_reward(predict_str, ground_truth)
return {
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import defaultdict
from typing import Any, Dict, List, Optional
import torch
from datasets import load_dataset
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
import verl.utils.torch_functional as verl_F
from verl.models.transformers.qwen2_5_vl import get_rope_index
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
non_tensors = defaultdict(list)
for feature in features:
for key, value in feature.items():
if isinstance(value, torch.Tensor):
tensors[key].append(value)
else:
non_tensors[key].append(value)
for key, value in tensors.items():
if key not in ["pixel_values", "image_grid_thw"]:
tensors[key] = torch.stack(value, dim=0)
return {**tensors, **non_tensors}
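# Illustrative sanity check (not repository code): collate_fn stacks
# same-shaped tensors such as input_ids, while pixel_values (variable shape)
# and non-tensor fields stay as plain lists. The feature dicts are invented.
if __name__ == "__main__":
    demo_features = [
        {"input_ids": torch.zeros(8, dtype=torch.long), "pixel_values": torch.randn(4, 3), "answer": "A"},
        {"input_ids": torch.ones(8, dtype=torch.long), "pixel_values": torch.randn(6, 3), "answer": "B"},
    ]
    demo_batch = collate_fn(demo_features)
    assert demo_batch["input_ids"].shape == (2, 8)  # stacked into one tensor
    assert isinstance(demo_batch["pixel_values"], list)  # kept as a list of tensors
    assert demo_batch["answer"] == ["A", "B"]  # non-tensor values collected into a list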
def process_image(image: ImageObject, max_pixels: int, min_pixels: int) -> ImageObject:
if (image.width * image.height) > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
if (image.width * image.height) < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
return image
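# A quick illustration with assumed values: an image larger than max_pixels is
# scaled down so that width * height does not exceed max_pixels.
if __name__ == "__main__":
    demo_image = Image.new("RGB", (2000, 2000))  # 4,000,000 pixels
    resized = process_image(demo_image, max_pixels=1_000_000, min_pixels=65_536)
    assert resized.width * resized.height <= 1_000_000  # roughly 1000 x 1000 after resizing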
class RLHFDataset(Dataset):
"""
    We assume the dataset contains a column with the prompts, alongside any other per-sample information.
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
prompt_key="prompt",
max_prompt_length=1024,
truncation="error",
system_prompt=None,
max_pixels=None,
min_pixels=None,
):
self.tokenizer = tokenizer
self.processor = processor
self.prompt_key = prompt_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.system_prompt = system_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
if "@" in data_path:
data_path, data_split = data_path.split("@")
else:
data_split = "train"
self.dataset = load_dataset(data_path, split=data_split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
"""
        Note that we also return the raw_prompt_ids so that they can be combined with other chat templates.
"""
row_dict = self.dataset[index]
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": row_dict[self.prompt_key]},
]
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
if "images" in row_dict: # expand image token
raw_prompt = prompt.replace("<image>", "<|vision_start|><|image_pad|><|vision_end|>")
row_dict["images"] = [
process_image(image, self.max_pixels, self.min_pixels) for image in row_dict["images"]
]
image_inputs = self.processor.image_processor(row_dict["images"], return_tensors="pt")
image_grid_thw = image_inputs["image_grid_thw"]
row_dict.update(image_inputs)
if image_grid_thw is not None:
merge_length = self.processor.image_processor.merge_size**2
index = 0
while "<image>" in prompt:
prompt = prompt.replace(
"<image>",
"<|vision_start|>"
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
+ "<|vision_end|>",
1,
)
index += 1
prompt = prompt.replace("<|placeholder|>", self.processor.image_token)
else:
raw_prompt = prompt
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
if "images" in row_dict:
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
) # (3, seq_len)
else:
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seqlen,)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
row_dict["raw_prompt_ids"] = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
return row_dict
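# A hedged end-to-end sketch: wiring this dataset into a DataLoader together
# with the collate_fn defined above. The dataset path, prompt column, and
# model name are placeholders, not values taken from this commit; a text-only
# dataset is assumed, so no image processor is needed.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    demo_dataset = RLHFDataset(
        data_path="my_org/my_rl_prompts@train",  # hypothetical "path@split" string
        tokenizer=demo_tokenizer,
        processor=None,  # text-only, so the image processor is unused
        prompt_key="question",  # hypothetical prompt column
        max_prompt_length=1024,
        truncation="error",
        system_prompt="You are a helpful assistant.",
    )
    demo_loader = DataLoader(demo_dataset, batch_size=2, collate_fn=collate_fn)
    demo_batch = next(iter(demo_loader))
    print(demo_batch["input_ids"].shape)  # expected: torch.Size([2, 1024])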
......@@ -94,27 +94,29 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -
return (values - mean) * torch.rsqrt(var + eps)
def get_eos_mask(response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long):
def get_response_mask(
response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long
):
"""Get the mask for the response ids, the mask will be 0 after the first eos token.
eos_token_id can be int or list: 1 or [1, 2].
```
    e.g. eos_token_id = 1
response_ids: [0, 0, 2, 4, 3, 5, 1, 0, 0]
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
response_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
```
"""
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_mask = torch.zeros_like(response_ids, dtype=torch.bool)
response_mask = torch.zeros_like(response_ids, dtype=torch.bool)
for token_id in eos_token_id:
eos_mask |= response_ids.eq(token_id)
response_mask |= response_ids.eq(token_id)
eos_mask = eos_mask.long()
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
eos_mask = torch.logical_not(eos_mask).to(dtype)
return eos_mask
response_mask = response_mask.long()
response_mask = (torch.cumsum(response_mask, dim=1) - response_mask).bool()
response_mask = torch.logical_not(response_mask).to(dtype)
return response_mask
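# A small sanity check mirroring the docstring example above (illustrative
# only; torch is assumed to be imported at the top of this module):
if __name__ == "__main__":
    demo_response_ids = torch.tensor([[0, 0, 2, 4, 3, 5, 1, 0, 0]])
    demo_mask = get_response_mask(demo_response_ids, eos_token_id=1)
    assert demo_mask.tolist() == [[1, 1, 1, 1, 1, 1, 1, 0, 0]]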
def pad_2d_list_to_length(
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backends
"""
import os
from typing import List, Union
from verl.utils.logger.aggregate_logger import LocalLogger
class Tracking:
supported_backend = ["wandb", "mlflow", "swanlab", "console"]
def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None):
if isinstance(default_backend, str):
default_backend = [default_backend]
for backend in default_backend:
assert backend in self.supported_backend, f"{backend} is not supported"
self.logger = {}
if "wandb" in default_backend:
import wandb # type: ignore
wandb.init(project=project_name, name=experiment_name, config=config)
self.logger["wandb"] = wandb
if "mlflow" in default_backend:
import mlflow # type: ignore
mlflow.start_run(run_name=experiment_name)
mlflow.log_params(config)
self.logger["mlflow"] = _MlflowLoggingAdapter()
if "swanlab" in default_backend:
import swanlab # type: ignore
SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None)
SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog")
SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud")
if SWANLAB_API_KEY:
swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten
swanlab.init(
project=project_name,
experiment_name=experiment_name,
config=config,
logdir=SWANLAB_LOG_DIR,
mode=SWANLAB_MODE,
)
self.logger["swanlab"] = swanlab
if "console" in default_backend:
self.console_logger = LocalLogger(print_to_console=True)
self.logger["console"] = self.console_logger
def log(self, data, step, backend=None):
for default_backend, logger_instance in self.logger.items():
if backend is None or default_backend in backend:
logger_instance.log(data=data, step=step)
def __del__(self):
if "wandb" in self.logger:
self.logger["wandb"].finish(exit_code=0)
if "swanlab" in self.logger:
self.logger["swanlab"].finish()
class _MlflowLoggingAdapter:
def log(self, data, step):
import mlflow # type: ignore
mlflow.log_metrics(metrics=data, step=step)
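# A hedged usage sketch for the console backend; the project name, experiment
# name, and metric keys below are placeholders.
if __name__ == "__main__":
    tracking = Tracking(
        project_name="demo_project",
        experiment_name="demo_run",
        default_backend="console",
        config={"lr": 1e-6, "rollout_n": 5},
    )
    tracking.log(data={"actor/loss": 0.5, "critic/score": 0.75}, step=1)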