Commit c132cbcb authored by chenych

0402 update

parent f92481f0
@@ -14,12 +14,13 @@
import os
import warnings
+from typing import Optional, Union

import torch
-import torch.distributed
+import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
-from transformers import PreTrainedTokenizer, ProcessorMixin
+from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from .checkpoint_manager import BaseCheckpointManager
@@ -44,65 +45,56 @@ class FSDPCheckpointManager(BaseCheckpointManager):
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
-        tokenizer: PreTrainedTokenizer,
-        processor: ProcessorMixin,
-        *args,
-        **kwargs,
+        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
-        super().__init__(model, optimizer, lr_scheduler, tokenizer, processor)
+        super().__init__(model, optimizer, lr_scheduler, processing_class)

-    def load_checkpoint(self, path=None, *args, **kwargs):
+    def load_checkpoint(self, path: Optional[str] = None):
        if path is None:
            return

        # every rank download its own checkpoint
-        local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
-        local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
-        local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
-        print(
-            f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}"
-        )
-        model_state_dict = torch.load(local_model_path)
-        optimizer_state_dict = torch.load(local_optim_path)
-        extra_state_dict = torch.load(local_extra_state_path)
+        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
+        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
+        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
+        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
+        model_state_dict = torch.load(model_path, weights_only=False)
+        optimizer_state_dict = torch.load(optim_path, weights_only=False)
+        extra_state_dict = torch.load(extra_state_path, weights_only=False)
        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

-        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
-        with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
-            self.model.load_state_dict(model_state_dict)
-            if self.optimizer is not None:
-                self.optimizer.load_state_dict(optimizer_state_dict)
-
-        # recover random state
-        if "rng" in extra_state_dict:
-            # 'rng' may not exist for backward compatibility
-            self.load_rng_state(extra_state_dict["rng"])
+        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
+        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
+                self.model.load_state_dict(model_state_dict)
+                if self.optimizer is not None:
+                    self.optimizer.load_state_dict(optimizer_state_dict)

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

-    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
-        # record the previous global step
-        self.previous_global_step = global_step
-
-        # remove previous local_path
-        # TODO: shall we remove previous ckpt every save?
-        if remove_previous_ckpt:
-            self.remove_previous_save_local_path()
-
-        local_path = self.local_mkdir(local_path)
-        torch.distributed.barrier()
+        # recover random state
+        if "rng" in extra_state_dict:
+            self.load_rng_state(extra_state_dict["rng"])
+
+    def save_checkpoint(self, path: str):
+        path = self.local_mkdir(path)
+        dist.barrier()

        # every rank will save its own model and optim shard
-        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
+        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
+        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
-            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
+            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                model_state_dict = self.model.state_dict()
                if self.optimizer is not None:
                    optimizer_state_dict = self.optimizer.state_dict()
                else:
                    optimizer_state_dict = None

                if self.lr_scheduler is not None:
                    lr_scheduler_state_dict = self.lr_scheduler.state_dict()
                else:
@@ -112,29 +104,28 @@ class FSDPCheckpointManager(BaseCheckpointManager):
                    "lr_scheduler": lr_scheduler_state_dict,
                    "rng": self.get_rng_state(),
                }
-                model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
-                optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
-                extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
+                model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
+                optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
+                extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

-                print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")
-                print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}")
-                print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")
+                print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
+                print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
+                print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
                torch.save(model_state_dict, model_path)
-                torch.save(optimizer_state_dict, optim_path)  # TODO: address optimizer is None
+                if self.optimizer is not None:
+                    torch.save(optimizer_state_dict, optim_path)
+
                torch.save(extra_state_dict, extra_path)

        # wait for everyone to dump to local
-        torch.distributed.barrier()
+        dist.barrier()

        if self.rank == 0:
-            hf_local_path = os.path.join(local_path, "huggingface")
-            os.makedirs(hf_local_path, exist_ok=True)
-            self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
-            if self.processor:
-                self.processor.save_pretrained(hf_local_path)
-            else:
-                self.tokenizer.save_pretrained(hf_local_path)
+            hf_path = os.path.join(path, "huggingface")
+            os.makedirs(hf_path, exist_ok=True)
+            assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
+            self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
+            self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
+            self.processing_class.save_pretrained(hf_path)

-        torch.distributed.barrier()
-        self.previous_save_local_path = local_path
+        dist.barrier()
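
For reference, a minimal usage sketch of the updated manager; the FSDP-wrapped model, optimizer, scheduler, tokenizer, and the checkpoint path are placeholders assumed to be built elsewhere:

# Illustrative wiring only; every name except FSDPCheckpointManager is assumed.
checkpoint_manager = FSDPCheckpointManager(
    model=fsdp_model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    processing_class=tokenizer,  # or a ProcessorMixin for vision-language models
)
checkpoint_manager.save_checkpoint("checkpoints/global_step_100")  # each rank writes its own shard
checkpoint_manager.load_checkpoint("checkpoints/global_step_100")  # each rank reads its own shard back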
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
non_tensors = defaultdict(list)
for feature in features:
for key, value in feature.items():
if isinstance(value, torch.Tensor):
tensors[key].append(value)
else:
non_tensors[key].append(value)
for key, value in tensors.items():
tensors[key] = torch.stack(value, dim=0)
for key, value in non_tensors.items():
non_tensors[key] = np.array(value, dtype=object)
return {**tensors, **non_tensors}
class ImageProcessMixin:
max_pixels: int
min_pixels: int
def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject:
if isinstance(image, dict):
image = Image.open(BytesIO(image["bytes"]))
elif isinstance(image, bytes):
image = Image.open(BytesIO(image))
if (image.width * image.height) > self.max_pixels:
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if (image.width * image.height) < self.min_pixels:
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image
class RLHFDataset(Dataset, ImageProcessMixin):
"""
We assume the dataset contains a column that contains prompts and other information
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
prompt_key: str = "prompt",
answer_key: str = "answer",
image_key: str = "images",
max_prompt_length: int = 1024,
truncation: str = "error",
system_prompt: str = None,
max_pixels: int = None,
min_pixels: int = None,
):
self.tokenizer = tokenizer
self.processor = processor
self.prompt_key = prompt_key
self.answer_key = answer_key
self.image_key = image_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.system_prompt = system_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
if "@" in data_path:
data_path, data_split = data_path.split("@")
else:
data_split = "train"
if os.path.isdir(data_path):
self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
elif os.path.isfile(data_path):
self.dataset = load_dataset("parquet", data_files=data_path, split="train")
else: # remote dataset
self.dataset = load_dataset(data_path, split=data_split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
row_dict: dict = self.dataset[index]
prompt_str: str = row_dict[self.prompt_key]
if self.system_prompt:
prompt_str = " ".join((self.system_prompt.strip(), prompt_str))
if self.image_key in row_dict:
# https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
if i != 0:
content_list.append({"type": "image"})
if content:
content_list.append({"type": "text", "text": content})
messages = [{"role": "user", "content": content_list}]
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
row_dict["multi_modal_data"] = {"image": images}
row_dict["multi_modal_inputs"] = dict(model_inputs)
# qwen2vl mrope
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=model_inputs["image_grid_thw"],
attention_mask=attention_mask,
) # (3, seq_length)
else:
messages = [{"role": "user", "content": prompt_str}]
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,)
input_ids, attention_mask, position_ids = VF.postprocess_data(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
row_dict["ground_truth"] = row_dict.pop(self.answer_key)
return row_dict
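
A minimal sketch of how this dataset is typically consumed; the checkpoint name, data path, and batch size are placeholders, and only RLHFDataset and collate_fn come from this file:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # example checkpoint
dataset = RLHFDataset(
    data_path="path/to/train.parquet",  # local parquet file, directory, or hub dataset (optionally "name@split")
    tokenizer=tokenizer,
    processor=None,  # pass a ProcessorMixin for image-text models
    max_prompt_length=1024,
    truncation="error",
)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
batch = next(iter(loader))  # tensors are stacked; non-tensor fields become object arrays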
@@ -12,22 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import TYPE_CHECKING, List, Tuple
+
import torch
-from transformers import LlamaConfig, PretrainedConfig, Qwen2Config

-VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
+if TYPE_CHECKING:
+    from transformers.models.llama.configuration_llama import LlamaConfig
+
+VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}

-def get_device_flops(unit="T"):
-    def unit_convert(number, level):
+def get_device_flops(unit: str = "T") -> float:
+    def unit_convert(number: float, level: str):
        units = ["B", "K", "M", "G", "T", "P"]
        if number <= 0:
            return number

        ptr = 0
        while ptr < len(units) and units[ptr] != level:
            number /= 1000
            ptr += 1

        return number

    device_name = torch.cuda.get_device_name()
@@ -55,21 +62,24 @@ class FlopsCounter:
    Example:
        flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
    """

-    def __init__(self, config: PretrainedConfig):
-        if not isinstance(config, VALID_CONFIG_TYPE):
-            print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. MFU will always be zero.")
+    def __init__(self, config: "LlamaConfig"):
+        if config.model_type not in VALID_MODLE_TYPE:
+            print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.")

-        self.estimate_func = {"qwen2": self._estimate_qwen2_flops, "llama": self._estimate_qwen2_flops}
+        self.estimate_func = {
+            "llama": self._estimate_llama_flops,
+            "qwen2": self._estimate_llama_flops,
+            "qwen2_vl": self._estimate_llama_flops,
+            "qwen2_5_vl": self._estimate_llama_flops,
+        }
        self.config = config

-    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
+    def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
        return 0

-    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
-        assert isinstance(self.config, (Qwen2Config, LlamaConfig))
+    def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
        hidden_size = self.config.hidden_size
        vocab_size = self.config.vocab_size
        num_hidden_layers = self.config.num_hidden_layers
@@ -96,6 +106,7 @@ class FlopsCounter:
        seqlen_square_sum = 0
        for seqlen in batch_seqlens:
            seqlen_square_sum += seqlen * seqlen
+
        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers

        # all_layer & all_token fwd & bwd flops
@@ -103,7 +114,7 @@ class FlopsCounter:
        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
        return flops_achieved

-    def estimate_flops(self, batch_seqlens, delta_time):
+    def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
        """
        Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
...
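
As a usage sketch mirroring the class docstring above; the model name, sequence lengths, and timing are illustrative, and the promised FLOPs require a visible CUDA device:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # example checkpoint
flops_counter = FlopsCounter(config)
batch_seqlens = [512, 1024, 2048]  # valid tokens per sample in the batch
flops_achieved, flops_promised = flops_counter.estimate_flops(batch_seqlens, delta_time=1.7)
mfu = flops_achieved / flops_promised  # model FLOPs utilization
print(f"MFU: {mfu:.2%}")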
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import gc
from collections import defaultdict
from functools import partial
from typing import Callable, Union
@@ -73,6 +74,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
    for handle in model._all_handles:
        if handle._offload_params:
            continue
+
        flat_param = handle.flat_param
        assert (
            flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
@@ -89,7 +91,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
@torch.no_grad()
-def load_fsdp_model(model: FSDP):
+def load_fsdp_model(model: FSDP, empty_cache: bool = True):
    # lazy init FSDP model
    _lazy_init(model, model)
    assert model._is_root, "Only support root model loading to GPU"
@@ -102,11 +104,15 @@ def load_fsdp_model(model: FSDP):
        # the following still keeps id(._local_shard) != id(.data)
        flat_param._local_shard = flat_param.data

+    if empty_cache:
+        gc.collect()
+

@torch.no_grad()
-def offload_fsdp_optimizer(optimizer: Optimizer):
+def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
    if not optimizer.state:
        return

    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
@@ -114,14 +120,21 @@ def offload_fsdp_optimizer(optimizer: Optimizer):
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)

+    if empty_cache:
+        torch.cuda.empty_cache()
+

@torch.no_grad()
-def load_fsdp_optimizer(optimizer: Optimizer):
+def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
    if not optimizer.state:
        return

    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cuda", non_blocking=True)

+    if empty_cache:
+        gc.collect()
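
A rough sketch of how these helpers pair up around a memory-heavy phase such as rollout; fsdp_model and optimizer are placeholders for objects built elsewhere:

# Free GPU memory before generation, then restore state for the next optimizer step.
offload_fsdp_model(fsdp_model, empty_cache=True)
offload_fsdp_optimizer(optimizer, empty_cache=True)
# ... run generation or other GPU-heavy work ...
load_fsdp_model(fsdp_model, empty_cache=True)
load_fsdp_optimizer(optimizer, empty_cache=True)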
@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from .logger import Tracker
+
+
+__all__ = ["Tracker"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple
from ..py_functional import is_package_available
if is_package_available("wandb"):
import wandb # type: ignore
if is_package_available("swanlab"):
import swanlab # type: ignore
@dataclass
class GenerationLogger(ABC):
@abstractmethod
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...
@dataclass
class ConsoleGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for inp, out, score in samples:
print(f"[prompt] {inp}\n[output] {out}\n[score] {score}\n")
@dataclass
class WandbGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)
if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)
# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)
# Add new row with all data
row_data = [step]
for sample in samples:
row_data.extend(sample)
new_table.add_data(*row_data)
wandb.log({"val/generations": new_table}, step=step)
self.validation_table = new_table
@dataclass
class SwanlabGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
swanlab_text_list = []
for i, sample in enumerate(samples):
row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
swanlab.log({"val/generations": swanlab_text_list}, step=step)
GEN_LOGGERS = {
"console": ConsoleGenerationLogger,
"wandb": WandbGenerationLogger,
"swanlab": SwanlabGenerationLogger,
}
@dataclass
class AggregateGenerationsLogger:
def __init__(self, loggers: List[str]):
self.loggers: List[GenerationLogger] = []
for logger in loggers:
if logger in GEN_LOGGERS:
self.loggers.append(GEN_LOGGERS[logger]())
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for logger in self.loggers:
logger.log(samples, step)
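
A small usage sketch; the sample tuples are made up, and the wandb/swanlab backends assume their runs were initialized elsewhere (for example by the Tracker in the next file):

gen_logger = AggregateGenerationsLogger(loggers=["console", "wandb"])
samples = [
    ("What is 2 + 2?", "<think>2 + 2 = 4</think> \\boxed{4}", 1.0),
    ("Name a prime number.", "<think>9 = 3 * 3</think> \\boxed{9}", 0.1),
]
gen_logger.log(samples, step=10)  # each backend renders the (input, output, score) triples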
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backend
"""
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.utils.tensorboard import SummaryWriter
from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict
from .gen_logger import AggregateGenerationsLogger
if is_package_available("mlflow"):
import mlflow # type: ignore
if is_package_available("wandb"):
import wandb # type: ignore
if is_package_available("swanlab"):
import swanlab # type: ignore
class Logger(ABC):
@abstractmethod
def __init__(self, config: Dict[str, Any]) -> None: ...
@abstractmethod
def log(self, data: Dict[str, Any], step: int) -> None: ...
def finish(self) -> None:
pass
class ConsoleLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
print("Config\n" + convert_dict_to_str(config))
def log(self, data: Dict[str, Any], step: int) -> None:
print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data)))
class MlflowLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
mlflow.start_run(run_name=config["trainer"]["experiment_name"])
mlflow.log_params(flatten_dict(config))
def log(self, data: Dict[str, Any], step: int) -> None:
mlflow.log_metrics(metrics=data, step=step)
class TensorBoardLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log")
os.makedirs(tensorboard_dir, exist_ok=True)
print(f"Saving tensorboard log to {tensorboard_dir}.")
self.writer = SummaryWriter(tensorboard_dir)
self.writer.add_hparams(flatten_dict(config))
def log(self, data: Dict[str, Any], step: int) -> None:
for key, value in data.items():
self.writer.add_scalar(key, value, step)
def finish(self):
self.writer.close()
class WandbLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
wandb.init(
project=config["trainer"]["project_name"],
name=config["trainer"]["experiment_name"],
config=config,
)
def log(self, data: Dict[str, Any], step: int) -> None:
wandb.log(data=data, step=step)
def finish(self) -> None:
wandb.finish()
class SwanlabLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
swanlab_key = os.getenv("SWANLAB_API_KEY")
swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
swanlab_mode = os.getenv("SWANLAB_MODE", "cloud")
if swanlab_key:
swanlab.login(swanlab_key)
swanlab.init(
project=config["trainer"]["project_name"],
experiment_name=config["trainer"]["experiment_name"],
config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
logdir=swanlab_dir,
mode=swanlab_mode,
)
def log(self, data: Dict[str, Any], step: int) -> None:
swanlab.log(data=data, step=step)
def finish(self) -> None:
swanlab.finish()
LOGGERS = {
"wandb": WandbLogger,
"mlflow": MlflowLogger,
"tensorboard": TensorBoardLogger,
"console": ConsoleLogger,
"swanlab": SwanlabLogger,
}
class Tracker:
def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
if isinstance(loggers, str):
loggers = [loggers]
self.loggers: List[Logger] = []
for logger in loggers:
if logger not in LOGGERS:
raise ValueError(f"{logger} is not supported.")
self.loggers.append(LOGGERS[logger](config))
self.gen_logger = AggregateGenerationsLogger(loggers)
def log(self, data: Dict[str, Any], step: int) -> None:
for logger in self.loggers:
logger.log(data=data, step=step)
def log_generation(self, samples: List[Tuple[str, str, float]], step: int) -> None:
self.gen_logger.log(samples, step)
def __del__(self):
for logger in self.loggers:
logger.finish()
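
A minimal usage sketch of the Tracker facade; the project and experiment names are placeholders, and the config keys follow what the backends above read:

config = {"trainer": {"project_name": "easy_r1", "experiment_name": "qwen2_5_vl_grpo"}}
tracker = Tracker(loggers="console", config=config)
tracker.log({"actor/loss": 0.42, "critic/score": 0.87}, step=1)  # keys are unflattened for console output
tracker.log_generation([("prompt", "response", 1.0)], step=1)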
@@ -15,11 +15,28 @@
Utilities to create common models
"""

+from functools import lru_cache
+from typing import Optional, Tuple
+
import torch
+import torch.distributed as dist
from torch import nn


-def get_model_size(model: nn.Module, scale="auto"):
+@lru_cache
+def is_rank0() -> int:
+    return (not dist.is_initialized()) or (dist.get_rank() == 0)
+
+
+def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None:
+    """Report the current GPU VRAM usage."""
+    if is_rank0():
+        free_mem, total_mem = torch.cuda.mem_get_info()
+        print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.")
+
+
+def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
+    """Compute the model size."""
    n_params = sum(p.numel() for p in model.parameters())
    if scale == "auto":
@@ -41,18 +58,16 @@ def get_model_size(model: nn.Module, scale="auto"):
    elif scale == "":
        pass
    else:
-        raise NotImplementedError(f"Unknown scale {scale}")
+        raise NotImplementedError(f"Unknown scale {scale}.")

    return n_params, scale


-def print_model_size(model: nn.Module, name: str = None):
-    n_params, scale = get_model_size(model, scale="auto")
-    if name is None:
-        name = model.__class__.__name__
-
-    print(f"{name} contains {n_params:.2f}{scale} parameters")
-
-
-def compute_position_id_with_mask(mask):
-    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
+def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
+    """Print the model size."""
+    if is_rank0():
+        n_params, scale = _get_model_size(model, scale="auto")
+        if name is None:
+            name = model.__class__.__name__
+
+        print(f"{name} contains {n_params:.2f}{scale} parameters.")
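
Usage is straightforward; the model below is an arbitrary example, and print_gpu_memory_usage needs a visible CUDA device:

from torch import nn

model = nn.Linear(4096, 4096)
print_model_size(model)  # e.g. "Linear contains 16.78M parameters."
print_gpu_memory_usage("After model init")  # only rank 0 prints once torch.distributed is initialized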
@@ -15,23 +15,89 @@
Contain small python utility functions
"""

-from typing import Any, Dict, List
+import importlib.util
+import re
+from functools import lru_cache
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import yaml
+from yaml import Dumper
+
+
+def is_sci_notation(number: float) -> bool:
+    pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$")
+    return bool(pattern.match(str(number)))
+
+
+def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
+    if is_sci_notation(number):
+        value = str(number)
+        if "." not in value and "e" in value:
+            value = value.replace("e", ".0e", 1)
+    else:
+        value = str(round(number, 3))
+
+    return dumper.represent_scalar("tag:yaml.org,2002:float", value)
+
+
+yaml.add_representer(float, float_representer)
+yaml.add_representer(np.float32, float_representer)
+yaml.add_representer(np.float64, float_representer)
+
+
+@lru_cache
+def is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None


def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
    """Union two dict. Will throw an error if there is an item not the same object with the same key."""
-    for key, value in dict2.items():
+    for key in dict2.keys():
        if key in dict1:
-            assert dict1[key] != value, f"{key} in meta_dict1 and meta_dict2 are not the same object"
+            assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object"

-        dict1[key] = value
+        dict1[key] = dict2[key]

    return dict1


def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
+    """Append dict to a dict of list."""
    for key, val in new_data.items():
        if key not in data:
            data[key] = []

        data[key].append(val)
+
+
+def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
+    unflattened = {}
+    for key, value in data.items():
+        pieces = key.split(sep)
+        pointer = unflattened
+        for piece in pieces[:-1]:
+            if piece not in pointer:
+                pointer[piece] = {}
+
+            pointer = pointer[piece]
+
+        pointer[pieces[-1]] = value
+
+    return unflattened
+
+
+def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
+    flattened = {}
+    for key, value in data.items():
+        new_key = parent_key + sep + key if parent_key else key
+        if isinstance(value, dict):
+            flattened.update(flatten_dict(value, new_key, sep=sep))
+        else:
+            flattened[new_key] = value
+
+    return flattened
+
+
+def convert_dict_to_str(data: Dict[str, Any]) -> str:
+    return yaml.dump(data, indent=2)
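
A quick illustration of the new dict helpers; the metric values are arbitrary:

metrics = {"actor": {"loss": 0.25, "lr": 1.0e-6}, "step_time": 3.2}
flat = flatten_dict(metrics)  # {"actor/loss": 0.25, "actor/lr": 1e-06, "step_time": 3.2}
nested = unflatten_dict(flat)  # back to the original nested dict
print(convert_dict_to_str(nested))  # YAML dump that uses the float representer registered above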
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+from typing import Dict
+
from mathruler.grader import extract_boxed_content, grade_answer


-def math_compute_score(predict_str: str, ground_truth: str) -> float:
+def math_format_reward(predict_str: str) -> float:
+    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
+    format_match = re.fullmatch(pattern, predict_str)
+    return 1.0 if format_match else 0.0
+
+
+def math_acc_reward(predict_str: str, ground_truth: str) -> float:
    answer = extract_boxed_content(predict_str)
-    if answer == "None":
-        return 0.0  # no answer
-
-    if grade_answer(answer, ground_truth):
-        return 1.0  # correct answer
-
-    return 0.1  # wrong answer
+    return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
+    format = math_format_reward(predict_str)
+    accuracy = math_acc_reward(predict_str, ground_truth)
+    return {
+        "overall": 0.9 * accuracy + 0.1 * format,
+        "format": format,
+        "accuracy": accuracy,
+    }
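
For example, a well-formatted and correct response now receives full reward, while a correct but unformatted one gets an overall score of 0.9:

predict = "<think>2 + 2 = 4</think> The answer is \\boxed{4}."
print(math_compute_score(predict, "4"))
# {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}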
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
+from typing import Dict
+
from mathruler.grader import grade_answer


def r1v_format_reward(predict_str: str) -> float:
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    match = re.fullmatch(pattern, predict_str, re.DOTALL)
-    return 1.0 if match else 0.0
+    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
+    format_match = re.fullmatch(pattern, predict_str)
+    return 1.0 if format_match else 0.0


def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
    try:
        ground_truth = ground_truth.strip()
        content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
-        pred_answer = content_match.group(1).strip() if content_match else predict_str.strip()
-        if grade_answer(pred_answer, ground_truth):
+        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
+        if grade_answer(given_answer, ground_truth):
            return 1.0
    except Exception:
        pass

    return 0.0


-def r1v_compute_score(predict_str: str, ground_truth: str) -> float:
-    acc_reward = r1v_accuracy_reward(predict_str, ground_truth)
-    format_reward = r1v_format_reward(predict_str)
-    reward = acc_reward + format_reward
-    reward /= 2
-    return reward
+def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
+    format = r1v_format_reward(predict_str)
+    accuracy = r1v_accuracy_reward(predict_str, ground_truth)
+    return {
+        "overall": 0.5 * accuracy + 0.5 * format,
+        "format": format,
+        "accuracy": accuracy,
+    }
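
Similarly for the R1-V style reward, now split evenly between format and accuracy:

predict = "<think>The image shows three cats.</think> <answer>3</answer>"
print(r1v_compute_score(predict, "3"))
# {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}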