Commit c132cbcb authored by chenych

0402 update

parent f92481f0
@@ -14,12 +14,13 @@
import os
import warnings
from typing import Optional, Union
import torch
import torch.distributed
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from .checkpoint_manager import BaseCheckpointManager
@@ -44,65 +45,56 @@ class FSDPCheckpointManager(BaseCheckpointManager):
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin,
*args,
**kwargs,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
):
super().__init__(model, optimizer, lr_scheduler, tokenizer, processor)
super().__init__(model, optimizer, lr_scheduler, processing_class)
def load_checkpoint(self, path=None, *args, **kwargs):
def load_checkpoint(self, path: Optional[str] = None):
if path is None:
return
# every rank download its own checkpoint
local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(
f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}"
)
model_state_dict = torch.load(local_model_path)
optimizer_state_dict = torch.load(local_optim_path)
extra_state_dict = torch.load(local_extra_state_path)
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
model_state_dict = torch.load(model_path, weights_only=False)
optimizer_state_dict = torch.load(optim_path, weights_only=False)
extra_state_dict = torch.load(extra_state_path, weights_only=False)
lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
# recover random state
if "rng" in extra_state_dict:
# 'rng' may not exist for backward compatibility
self.load_rng_state(extra_state_dict["rng"])
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
# record the previous global step
self.previous_global_step = global_step
# recover random state
if "rng" in extra_state_dict:
self.load_rng_state(extra_state_dict["rng"])
# remove previous local_path
# TODO: shall we remove previous ckpt every save?
if remove_previous_ckpt:
self.remove_previous_save_local_path()
local_path = self.local_mkdir(local_path)
torch.distributed.barrier()
def save_checkpoint(self, path: str):
path = self.local_mkdir(path)
dist.barrier()
# every rank will save its own model and optim shard
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
model_state_dict = self.model.state_dict()
if self.optimizer is not None:
optimizer_state_dict = self.optimizer.state_dict()
else:
optimizer_state_dict = None
if self.lr_scheduler is not None:
lr_scheduler_state_dict = self.lr_scheduler.state_dict()
else:
@@ -112,29 +104,28 @@ class FSDPCheckpointManager(BaseCheckpointManager):
"lr_scheduler": lr_scheduler_state_dict,
"rng": self.get_rng_state(),
}
model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
torch.save(model_state_dict, model_path)
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
if self.optimizer is not None:
torch.save(optimizer_state_dict, optim_path)
torch.save(extra_state_dict, extra_path)
# wait for everyone to dump to local
torch.distributed.barrier()
dist.barrier()
if self.rank == 0:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
if self.processor:
self.processor.save_pretrained(hf_local_path)
else:
self.tokenizer.save_pretrained(hf_local_path)
torch.distributed.barrier()
self.previous_save_local_path = local_path
hf_path = os.path.join(path, "huggingface")
os.makedirs(hf_path, exist_ok=True)
assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
self.processing_class.save_pretrained(hf_path)
dist.barrier()
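A minimal usage sketch of the updated interface (illustrative only; it assumes an FSDP-wrapped model, optimizer, scheduler, and tokenizer have already been built):

checkpoint_manager = FSDPCheckpointManager(
    model=fsdp_model,  # assumed FSDP-wrapped PreTrainedModel
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    processing_class=tokenizer,  # a PreTrainedTokenizer or ProcessorMixin
)
checkpoint_manager.save_checkpoint(path="checkpoints/global_step_100")
checkpoint_manager.load_checkpoint(path="checkpoints/global_step_100")  # no-op when path is None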
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
non_tensors = defaultdict(list)
for feature in features:
for key, value in feature.items():
if isinstance(value, torch.Tensor):
tensors[key].append(value)
else:
non_tensors[key].append(value)
for key, value in tensors.items():
tensors[key] = torch.stack(value, dim=0)
for key, value in non_tensors.items():
non_tensors[key] = np.array(value, dtype=object)
return {**tensors, **non_tensors}
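For reference, a small illustrative example of what collate_fn returns (values are made up):

features = [
    {"input_ids": torch.tensor([1, 2]), "ground_truth": "4"},
    {"input_ids": torch.tensor([3, 4]), "ground_truth": "9"},
]
batch = collate_fn(features)
# batch["input_ids"] is a stacked tensor of shape (2, 2);
# batch["ground_truth"] is a NumPy object array of the raw strings.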
class ImageProcessMixin:
max_pixels: int
min_pixels: int
def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject:
if isinstance(image, dict):
image = Image.open(BytesIO(image["bytes"]))
elif isinstance(image, bytes):
image = Image.open(BytesIO(image))
if (image.width * image.height) > self.max_pixels:
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if (image.width * image.height) < self.min_pixels:
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image
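An illustrative check of the pixel bounds above (hypothetical limits; the mixin only needs max_pixels and min_pixels to be set):

mixin = ImageProcessMixin()
mixin.max_pixels, mixin.min_pixels = 1024 * 1024, 256 * 256
image = Image.new("RGB", (2000, 2000))
image = mixin.process_image(image)
print(image.size)  # (1024, 1024): downscaled so width * height stays within max_pixels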
class RLHFDataset(Dataset, ImageProcessMixin):
"""
We assume the dataset contains a column of prompts along with other information.
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
prompt_key: str = "prompt",
answer_key: str = "answer",
image_key: str = "images",
max_prompt_length: int = 1024,
truncation: str = "error",
system_prompt: Optional[str] = None,
max_pixels: Optional[int] = None,
min_pixels: Optional[int] = None,
):
self.tokenizer = tokenizer
self.processor = processor
self.prompt_key = prompt_key
self.answer_key = answer_key
self.image_key = image_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.system_prompt = system_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
if "@" in data_path:
data_path, data_split = data_path.split("@")
else:
data_split = "train"
if os.path.isdir(data_path):
self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
elif os.path.isfile(data_path):
self.dataset = load_dataset("parquet", data_files=data_path, split="train")
else: # remote dataset
self.dataset = load_dataset(data_path, split=data_split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
row_dict: dict = self.dataset[index]
prompt_str: str = row_dict[self.prompt_key]
if self.system_prompt:
prompt_str = " ".join((self.system_prompt.strip(), prompt_str))
if self.image_key in row_dict:
# https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
if i != 0:
content_list.append({"type": "image"})
if content:
content_list.append({"type": "text", "text": content})
messages = [{"role": "user", "content": content_list}]
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
row_dict["multi_modal_data"] = {"image": images}
row_dict["multi_modal_inputs"] = dict(model_inputs)
# qwen2vl mrope
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=model_inputs["image_grid_thw"],
attention_mask=attention_mask,
) # (3, seq_length)
else:
messages = [{"role": "user", "content": prompt_str}]
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,)
input_ids, attention_mask, position_ids = VF.postprocess_data(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
row_dict["ground_truth"] = row_dict.pop(self.answer_key)
return row_dict
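A rough end-to-end sketch of how the dataset and collate_fn could be wired together (the path and the tokenizer/processor objects are assumptions for illustration):

from torch.utils.data import DataLoader

dataset = RLHFDataset(
    data_path="data/geometry3k@train",  # hypothetical local or hub dataset
    tokenizer=tokenizer,
    processor=processor,  # None for text-only models
    max_prompt_length=1024,
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))  # dict with input_ids, attention_mask, position_ids, ground_truth, ...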
@@ -12,22 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Tuple
import torch
from transformers import LlamaConfig, PretrainedConfig, Qwen2Config
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
if TYPE_CHECKING:
from transformers.models.llama.configuration_llama import LlamaConfig
VALID_MODEL_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
def get_device_flops(unit="T"):
def unit_convert(number, level):
def get_device_flops(unit: str = "T") -> float:
def unit_convert(number: float, level: str):
units = ["B", "K", "M", "G", "T", "P"]
if number <= 0:
return number
ptr = 0
while ptr < len(units) and units[ptr] != level:
number /= 1000
ptr += 1
return number
device_name = torch.cuda.get_device_name()
@@ -55,21 +62,24 @@ class FlopsCounter:
Example:
flops_counter = FlopsCounter(config)
flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
"""
def __init__(self, config: PretrainedConfig):
if not isinstance(config, VALID_CONFIG_TYPE):
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. MFU will always be zero.")
def __init__(self, config: "LlamaConfig"):
if config.model_type not in VALID_MODEL_TYPE:
print(f"Only support {VALID_MODEL_TYPE}, but got {config.model_type}. MFU will always be zero.")
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, "llama": self._estimate_qwen2_flops}
self.estimate_func = {
"llama": self._estimate_llama_flops,
"qwen2": self._estimate_llama_flops,
"qwen2_vl": self._estimate_llama_flops,
"qwen2_5_vl": self._estimate_llama_flops,
}
self.config = config
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
return 0
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
hidden_size = self.config.hidden_size
vocab_size = self.config.vocab_size
num_hidden_layers = self.config.num_hidden_layers
@@ -96,6 +106,7 @@ class FlopsCounter:
seqlen_square_sum = 0
for seqlen in batch_seqlens:
seqlen_square_sum += seqlen * seqlen
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
# all_layer & all_token fwd & bwd flops
@@ -103,7 +114,7 @@ class FlopsCounter:
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
return flops_achieved
def estimate_flops(self, batch_seqlens, delta_time):
def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
"""
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from collections import defaultdict
from functools import partial
from typing import Callable, Union
@@ -73,6 +74,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
for handle in model._all_handles:
if handle._offload_params:
continue
flat_param = handle.flat_param
assert (
flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
@@ -89,7 +91,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
@torch.no_grad()
def load_fsdp_model(model: FSDP):
def load_fsdp_model(model: FSDP, empty_cache: bool = True):
# lazy init FSDP model
_lazy_init(model, model)
assert model._is_root, "Only support root model loading to GPU"
@@ -102,11 +104,15 @@ def load_fsdp_model(model: FSDP):
# the following still keeps id(._local_shard) != id(.data)
flat_param._local_shard = flat_param.data
if empty_cache:
gc.collect()
@torch.no_grad()
def offload_fsdp_optimizer(optimizer: Optimizer):
def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
if not optimizer.state:
return
for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
@@ -114,14 +120,21 @@ def offload_fsdp_optimizer(optimizer: Optimizer):
if isinstance(value, torch.Tensor):
state[key] = value.to("cpu", non_blocking=True)
if empty_cache:
torch.cuda.empty_cache()
@torch.no_grad()
def load_fsdp_optimizer(optimizer: Optimizer):
def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
if not optimizer.state:
return
for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
for key, value in state.items():
if isinstance(value, torch.Tensor):
state[key] = value.to("cuda", non_blocking=True)
if empty_cache:
gc.collect()
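These helpers can be paired around memory-heavy phases such as rollout; a hedged sketch assuming fsdp_model and optimizer already exist:

offload_fsdp_model(fsdp_model)      # move flat parameters off the GPU
offload_fsdp_optimizer(optimizer)   # move optimizer state to CPU
# ... run generation / rollout while GPU memory is free ...
load_fsdp_model(fsdp_model)         # restore parameters to GPU
load_fsdp_optimizer(optimizer)      # restore optimizer state to GPU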
@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .logger import Tracker
__all__ = ["Tracker"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple
from ..py_functional import is_package_available
if is_package_available("wandb"):
import wandb # type: ignore
if is_package_available("swanlab"):
import swanlab # type: ignore
@dataclass
class GenerationLogger(ABC):
@abstractmethod
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...
@dataclass
class ConsoleGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for inp, out, score in samples:
print(f"[prompt] {inp}\n[output] {out}\n[score] {score}\n")
@dataclass
class WandbGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)
if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)
# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)
# Add new row with all data
row_data = [step]
for sample in samples:
row_data.extend(sample)
new_table.add_data(*row_data)
wandb.log({"val/generations": new_table}, step=step)
self.validation_table = new_table
@dataclass
class SwanlabGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
swanlab_text_list = []
for i, sample in enumerate(samples):
row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
swanlab.log({"val/generations": swanlab_text_list}, step=step)
GEN_LOGGERS = {
"console": ConsoleGenerationLogger,
"wandb": WandbGenerationLogger,
"swanlab": SwanlabGenerationLogger,
}
@dataclass
class AggregateGenerationsLogger:
def __init__(self, loggers: List[str]):
self.loggers: List[GenerationLogger] = []
for logger in loggers:
if logger in GEN_LOGGERS:
self.loggers.append(GEN_LOGGERS[logger]())
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for logger in self.loggers:
logger.log(samples, step)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backends
"""
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.utils.tensorboard import SummaryWriter
from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict
from .gen_logger import AggregateGenerationsLogger
if is_package_available("mlflow"):
import mlflow # type: ignore
if is_package_available("wandb"):
import wandb # type: ignore
if is_package_available("swanlab"):
import swanlab # type: ignore
class Logger(ABC):
@abstractmethod
def __init__(self, config: Dict[str, Any]) -> None: ...
@abstractmethod
def log(self, data: Dict[str, Any], step: int) -> None: ...
def finish(self) -> None:
pass
class ConsoleLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
print("Config\n" + convert_dict_to_str(config))
def log(self, data: Dict[str, Any], step: int) -> None:
print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data)))
class MlflowLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
mlflow.start_run(run_name=config["trainer"]["experiment_name"])
mlflow.log_params(flatten_dict(config))
def log(self, data: Dict[str, Any], step: int) -> None:
mlflow.log_metrics(metrics=data, step=step)
class TensorBoardLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log")
os.makedirs(tensorboard_dir, exist_ok=True)
print(f"Saving tensorboard log to {tensorboard_dir}.")
self.writer = SummaryWriter(tensorboard_dir)
self.writer.add_hparams(flatten_dict(config), metric_dict={})  # add_hparams requires a metric_dict
def log(self, data: Dict[str, Any], step: int) -> None:
for key, value in data.items():
self.writer.add_scalar(key, value, step)
def finish(self):
self.writer.close()
class WandbLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
wandb.init(
project=config["trainer"]["project_name"],
name=config["trainer"]["experiment_name"],
config=config,
)
def log(self, data: Dict[str, Any], step: int) -> None:
wandb.log(data=data, step=step)
def finish(self) -> None:
wandb.finish()
class SwanlabLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
swanlab_key = os.getenv("SWANLAB_API_KEY")
swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
swanlab_mode = os.getenv("SWANLAB_MODE", "cloud")
if swanlab_key:
swanlab.login(swanlab_key)
swanlab.init(
project=config["trainer"]["project_name"],
experiment_name=config["trainer"]["experiment_name"],
config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
logdir=swanlab_dir,
mode=swanlab_mode,
)
def log(self, data: Dict[str, Any], step: int) -> None:
swanlab.log(data=data, step=step)
def finish(self) -> None:
swanlab.finish()
LOGGERS = {
"wandb": WandbLogger,
"mlflow": MlflowLogger,
"tensorboard": TensorBoardLogger,
"console": ConsoleLogger,
"swanlab": SwanlabLogger,
}
class Tracker:
def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
if isinstance(loggers, str):
loggers = [loggers]
self.loggers: List[Logger] = []
for logger in loggers:
if logger not in LOGGERS:
raise ValueError(f"{logger} is not supported.")
self.loggers.append(LOGGERS[logger](config))
self.gen_logger = AggregateGenerationsLogger(loggers)
def log(self, data: Dict[str, Any], step: int) -> None:
for logger in self.loggers:
logger.log(data=data, step=step)
def log_generation(self, samples: List[Tuple[str, str, float]], step: int) -> None:
self.gen_logger.log(samples, step)
def __del__(self):
for logger in self.loggers:
logger.finish()
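A minimal usage sketch of the Tracker (the config keys mirror the trainer.project_name / trainer.experiment_name fields the loggers read; values are illustrative):

config = {"trainer": {"project_name": "easy_r1", "experiment_name": "demo_run"}}
tracker = Tracker(loggers=["console"], config=config)
tracker.log({"train/loss": 0.42, "train/lr": 1e-5}, step=1)
tracker.log_generation([("2 + 2 = ?", "\\boxed{4}", 1.0)], step=1)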
@@ -15,11 +15,28 @@
Utilities to create common models
"""
from functools import lru_cache
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
def get_model_size(model: nn.Module, scale="auto"):
@lru_cache
def is_rank0() -> bool:
return (not dist.is_initialized()) or (dist.get_rank() == 0)
def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None:
"""Report the current GPU VRAM usage."""
if is_rank0():
free_mem, total_mem = torch.cuda.mem_get_info()
print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.")
def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
"""Compute the model size."""
n_params = sum(p.numel() for p in model.parameters())
if scale == "auto":
@@ -41,18 +58,16 @@ def get_model_size(model: nn.Module, scale="auto"):
elif scale == "":
pass
else:
raise NotImplementedError(f"Unknown scale {scale}")
raise NotImplementedError(f"Unknown scale {scale}.")
return n_params, scale
def print_model_size(model: nn.Module, name: str = None):
n_params, scale = get_model_size(model, scale="auto")
if name is None:
name = model.__class__.__name__
print(f"{name} contains {n_params:.2f}{scale} parameters")
def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
"""Print the model size."""
if is_rank0():
n_params, scale = _get_model_size(model, scale="auto")
if name is None:
name = model.__class__.__name__
def compute_position_id_with_mask(mask):
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
print(f"{name} contains {n_params:.2f}{scale} parameters.")
@@ -15,23 +15,89 @@
Contains small Python utility functions
"""
from typing import Any, Dict, List
import importlib.util
import re
from functools import lru_cache
from typing import Any, Dict, List, Union
import numpy as np
import yaml
from yaml import Dumper
def is_sci_notation(number: float) -> bool:
pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$")
return bool(pattern.match(str(number)))
def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
if is_sci_notation(number):
value = str(number)
if "." not in value and "e" in value:
value = value.replace("e", ".0e", 1)
else:
value = str(round(number, 3))
return dumper.represent_scalar("tag:yaml.org,2002:float", value)
yaml.add_representer(float, float_representer)
yaml.add_representer(np.float32, float_representer)
yaml.add_representer(np.float64, float_representer)
@lru_cache
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
"""Union two dict. Will throw an error if there is an item not the same object with the same key."""
for key, value in dict2.items():
for key in dict2.keys():
if key in dict1:
assert dict1[key] != value, f"{key} in meta_dict1 and meta_dict2 are not the same object"
assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object"
dict1[key] = value
dict1[key] = dict2[key]
return dict1
def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
"""Append dict to a dict of list."""
for key, val in new_data.items():
if key not in data:
data[key] = []
data[key].append(val)
def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
unflattened = {}
for key, value in data.items():
pieces = key.split(sep)
pointer = unflattened
for piece in pieces[:-1]:
if piece not in pointer:
pointer[piece] = {}
pointer = pointer[piece]
pointer[pieces[-1]] = value
return unflattened
def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
flattened = {}
for key, value in data.items():
new_key = parent_key + sep + key if parent_key else key
if isinstance(value, dict):
flattened.update(flatten_dict(value, new_key, sep=sep))
else:
flattened[new_key] = value
return flattened
def convert_dict_to_str(data: Dict[str, Any]) -> str:
return yaml.dump(data, indent=2)
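flatten_dict and unflatten_dict are inverses for string-keyed nested dicts; a quick illustrative check:

nested = {"worker": {"rollout": {"top_p": 0.9}}}
flat = flatten_dict(nested)  # {"worker/rollout/top_p": 0.9}
assert unflatten_dict(flat) == nested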
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict
from mathruler.grader import extract_boxed_content, grade_answer
def math_compute_score(predict_str: str, ground_truth: str) -> float:
def math_format_reward(predict_str: str) -> float:
pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0
def math_acc_reward(predict_str: str, ground_truth: str) -> float:
answer = extract_boxed_content(predict_str)
if answer == "None":
return 0.0 # no answer
return 1.0 if grade_answer(answer, ground_truth) else 0.0
if grade_answer(answer, ground_truth):
return 1.0 # correct answer
return 0.1 # wrong answer
def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
format = math_format_reward(predict_str)
accuracy = math_acc_reward(predict_str, ground_truth)
return {
"overall": 0.9 * accuracy + 0.1 * format,
"format": format,
"accuracy": accuracy,
}
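An illustrative check of the weighting above (0.9 * accuracy + 0.1 * format; the expected output assumes mathruler grades "4" against "4" as correct):

score = math_compute_score("<think>2 + 2 = 4</think> The answer is \\boxed{4}.", "4")
# expected: {"overall": 1.0, "format": 1.0, "accuracy": 1.0}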
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict
from mathruler.grader import grade_answer
def r1v_format_reward(predict_str: str) -> float:
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
match = re.fullmatch(pattern, predict_str, re.DOTALL)
return 1.0 if match else 0.0
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0
def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
try:
ground_truth = ground_truth.strip()
content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
pred_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(pred_answer, ground_truth):
given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(given_answer, ground_truth):
return 1.0
except Exception:
pass
return 0.0
def r1v_compute_score(predict_str: str, ground_truth: str) -> float:
acc_reward = r1v_accuracy_reward(predict_str, ground_truth)
format_reward = r1v_format_reward(predict_str)
reward = acc_reward + format_reward
reward /= 2
return reward
def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
format = r1v_format_reward(predict_str)
accuracy = r1v_accuracy_reward(predict_str, ground_truth)
return {
"overall": 0.5 * accuracy + 0.5 * format,
"format": format,
"accuracy": accuracy,
}
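And similarly for the R1-V score, which weights format and accuracy equally:

score = r1v_compute_score("<think>count the dots</think> <answer>3</answer>", "3")
# expected: {"overall": 1.0, "format": 1.0, "accuracy": 1.0}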