Commit c132cbcb authored by chenych

0402 update

parent f92481f0
@@ -14,12 +14,13 @@
import os
import warnings
from typing import Optional, Union
import torch
import torch.distributed
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from .checkpoint_manager import BaseCheckpointManager
@@ -44,65 +45,56 @@ class FSDPCheckpointManager(BaseCheckpointManager):
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin,
*args,
**kwargs,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
):
super().__init__(model, optimizer, lr_scheduler, tokenizer, processor)
super().__init__(model, optimizer, lr_scheduler, processing_class)
def load_checkpoint(self, path=None, *args, **kwargs):
def load_checkpoint(self, path: Optional[str] = None):
if path is None:
return
# every rank downloads its own checkpoint
local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(
f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}"
)
model_state_dict = torch.load(local_model_path)
optimizer_state_dict = torch.load(local_optim_path)
extra_state_dict = torch.load(local_extra_state_path)
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
model_state_dict = torch.load(model_path, weights_only=False)
optimizer_state_dict = torch.load(optim_path, weights_only=False)
extra_state_dict = torch.load(extra_state_path, weights_only=False)
lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
# recover random state
if "rng" in extra_state_dict:
# 'rng' may be absent in older checkpoints (kept for backward compatibility)
self.load_rng_state(extra_state_dict["rng"])
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
# record the previous global step
self.previous_global_step = global_step
# recover random state
if "rng" in extra_state_dict:
self.load_rng_state(extra_state_dict["rng"])
# remove previous local_path
# TODO: shall we remove previous ckpt every save?
if remove_previous_ckpt:
self.remove_previous_save_local_path()
local_path = self.local_mkdir(local_path)
torch.distributed.barrier()
def save_checkpoint(self, path: str):
path = self.local_mkdir(path)
dist.barrier()
# every rank will save its own model and optim shard
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
model_state_dict = self.model.state_dict()
if self.optimizer is not None:
optimizer_state_dict = self.optimizer.state_dict()
else:
optimizer_state_dict = None
if self.lr_scheduler is not None:
lr_scheduler_state_dict = self.lr_scheduler.state_dict()
else:
@@ -112,29 +104,28 @@ class FSDPCheckpointManager(BaseCheckpointManager):
"lr_scheduler": lr_scheduler_state_dict,
"rng": self.get_rng_state(),
}
model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
torch.save(model_state_dict, model_path)
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
if self.optimizer is not None:
torch.save(optimizer_state_dict, optim_path)
torch.save(extra_state_dict, extra_path)
# wait for everyone to dump to local
torch.distributed.barrier()
dist.barrier()
if self.rank == 0:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
if self.processor:
self.processor.save_pretrained(hf_local_path)
else:
self.tokenizer.save_pretrained(hf_local_path)
torch.distributed.barrier()
self.previous_save_local_path = local_path
hf_path = os.path.join(path, "huggingface")
os.makedirs(hf_path, exist_ok=True)
assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
self.processing_class.save_pretrained(hf_path)
dist.barrier()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
non_tensors = defaultdict(list)
for feature in features:
for key, value in feature.items():
if isinstance(value, torch.Tensor):
tensors[key].append(value)
else:
non_tensors[key].append(value)
for key, value in tensors.items():
tensors[key] = torch.stack(value, dim=0)
for key, value in non_tensors.items():
non_tensors[key] = np.array(value, dtype=object)
return {**tensors, **non_tensors}
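# Illustrative usage of collate_fn (not part of the diff): tensor fields are
# stacked along a new batch dim, everything else becomes an object-dtype numpy array.
#
#   >>> batch = collate_fn([
#   ...     {"input_ids": torch.tensor([1, 2]), "ground_truth": "2"},
#   ...     {"input_ids": torch.tensor([3, 4]), "ground_truth": "4"},
#   ... ])
#   >>> batch["input_ids"].shape   # torch.Size([2, 2])
#   >>> batch["ground_truth"]      # array(['2', '4'], dtype=object)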
class ImageProcessMixin:
max_pixels: int
min_pixels: int
def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject:
if isinstance(image, dict):
image = Image.open(BytesIO(image["bytes"]))
elif isinstance(image, bytes):
image = Image.open(BytesIO(image))
if (image.width * image.height) > self.max_pixels:
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if (image.width * image.height) < self.min_pixels:
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image
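# A quick worked example of the resize rule above (hypothetical numbers): with
# max_pixels = 1_000_000, a 2000x2000 input has 4_000_000 pixels, so
# resize_factor = sqrt(1_000_000 / 4_000_000) = 0.5 and the image is resized
# to 1000x1000, which lands exactly on the max_pixels budget.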
class RLHFDataset(Dataset, ImageProcessMixin):
"""
We assume the dataset contains a column that contains prompts and other information
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
prompt_key: str = "prompt",
answer_key: str = "answer",
image_key: str = "images",
max_prompt_length: int = 1024,
truncation: str = "error",
system_prompt: str = None,
max_pixels: int = None,
min_pixels: int = None,
):
self.tokenizer = tokenizer
self.processor = processor
self.prompt_key = prompt_key
self.answer_key = answer_key
self.image_key = image_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.system_prompt = system_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
if "@" in data_path:
data_path, data_split = data_path.split("@")
else:
data_split = "train"
if os.path.isdir(data_path):
self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
elif os.path.isfile(data_path):
self.dataset = load_dataset("parquet", data_files=data_path, split="train")
else: # remote dataset
self.dataset = load_dataset(data_path, split=data_split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
row_dict: dict = self.dataset[index]
prompt_str: str = row_dict[self.prompt_key]
if self.system_prompt:
prompt_str = " ".join((self.system_prompt.strip(), prompt_str))
if self.image_key in row_dict:
# https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
if i != 0:
content_list.append({"type": "image"})
if content:
content_list.append({"type": "text", "text": content})
messages = [{"role": "user", "content": content_list}]
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
row_dict["multi_modal_data"] = {"image": images}
row_dict["multi_modal_inputs"] = dict(model_inputs)
# qwen2vl mrope
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=model_inputs["image_grid_thw"],
attention_mask=attention_mask,
) # (3, seq_length)
else:
messages = [{"role": "user", "content": prompt_str}]
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,)
input_ids, attention_mask, position_ids = VF.postprocess_data(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
row_dict["ground_truth"] = row_dict.pop(self.answer_key)
return row_dict
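# A minimal usage sketch (illustrative; "username/dataset" and the model path
# are placeholders, not values from this repo):
#
#   from transformers import AutoProcessor, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
#   dataset = RLHFDataset(
#       data_path="username/dataset@train",
#       tokenizer=tokenizer,
#       processor=processor,
#       max_prompt_length=1024,
#       truncation="right",
#   )
#   example = dataset[0]  # dict with input_ids, attention_mask, position_ids, ...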
@@ -12,22 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Tuple
import torch
from transformers import LlamaConfig, PretrainedConfig, Qwen2Config
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
if TYPE_CHECKING:
from transformers.models.llama.configuration_llama import LlamaConfig
VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
def get_device_flops(unit="T"):
def unit_convert(number, level):
def get_device_flops(unit: str = "T") -> float:
def unit_convert(number: float, level: str):
units = ["B", "K", "M", "G", "T", "P"]
if number <= 0:
return number
ptr = 0
while ptr < len(units) and units[ptr] != level:
number /= 1000
ptr += 1
return number
device_name = torch.cuda.get_device_name()
@@ -55,21 +62,24 @@ class FlopsCounter:
Example:
flops_counter = FlopsCounter(config)
flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
"""
def __init__(self, config: PretrainedConfig):
if not isinstance(config, VALID_CONFIG_TYPE):
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. MFU will always be zero.")
def __init__(self, config: "LlamaConfig"):
if config.model_type not in VALID_MODLE_TYPE:
print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.")
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, "llama": self._estimate_qwen2_flops}
self.estimate_func = {
"llama": self._estimate_llama_flops,
"qwen2": self._estimate_llama_flops,
"qwen2_vl": self._estimate_llama_flops,
"qwen2_5_vl": self._estimate_llama_flops,
}
self.config = config
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
return 0
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
hidden_size = self.config.hidden_size
vocab_size = self.config.vocab_size
num_hidden_layers = self.config.num_hidden_layers
@@ -96,6 +106,7 @@ class FlopsCounter:
seqlen_square_sum = 0
for seqlen in batch_seqlens:
seqlen_square_sum += seqlen * seqlen
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
# all_layer & all_token fwd & bwd flops
@@ -103,7 +114,7 @@ class FlopsCounter:
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
return flops_achieved
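# Back-of-the-envelope reading of the estimate above (assumed, following the
# usual "6 * params * tokens" rule for dense transformers): each matmul weight
# contributes 2 FLOPs in the forward pass and 4 in the backward pass, while the
# causal-attention term visible above adds 12 * sum(seqlen_i^2) * head_dim *
# num_attention_heads * num_hidden_layers, which is why longer sequences raise
# FLOPs superlinearly.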
def estimate_flops(self, batch_seqlens, delta_time):
def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
"""
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from collections import defaultdict
from functools import partial
from typing import Callable, Union
@@ -73,6 +74,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
for handle in model._all_handles:
if handle._offload_params:
continue
flat_param = handle.flat_param
assert (
flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
@@ -89,7 +91,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
@torch.no_grad()
def load_fsdp_model(model: FSDP):
def load_fsdp_model(model: FSDP, empty_cache: bool = True):
# lazy init FSDP model
_lazy_init(model, model)
assert model._is_root, "Only support root model loading to GPU"
@@ -102,11 +104,15 @@ def load_fsdp_model(model: FSDP):
# the following still keeps id(._local_shard) != id(.data)
flat_param._local_shard = flat_param.data
if empty_cache:
gc.collect()
@torch.no_grad()
def offload_fsdp_optimizer(optimizer: Optimizer):
def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
if not optimizer.state:
return
for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
@@ -114,14 +120,21 @@ def offload_fsdp_optimizer(optimizer: Optimizer):
if isinstance(value, torch.Tensor):
state[key] = value.to("cpu", non_blocking=True)
if empty_cache:
torch.cuda.empty_cache()
@torch.no_grad()
def load_fsdp_optimizer(optimizer: Optimizer):
def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
if not optimizer.state:
return
for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
for key, value in state.items():
if isinstance(value, torch.Tensor):
state[key] = value.to("cuda", non_blocking=True)
if empty_cache:
gc.collect()
@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .logger import Tracker
__all__ = ["Tracker"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple
from ..py_functional import is_package_available
if is_package_available("wandb"):
import wandb # type: ignore
if is_package_available("swanlab"):
import swanlab # type: ignore
@dataclass
class GenerationLogger(ABC):
@abstractmethod
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...
@dataclass
class ConsoleGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for inp, out, score in samples:
print(f"[prompt] {inp}\n[output] {out}\n[score] {score}\n")
@dataclass
class WandbGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)
if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)
# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)
# Add new row with all data
row_data = [step]
for sample in samples:
row_data.extend(sample)
new_table.add_data(*row_data)
wandb.log({"val/generations": new_table}, step=step)
self.validation_table = new_table
@dataclass
class SwanlabGenerationLogger(GenerationLogger):
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
swanlab_text_list = []
for i, sample in enumerate(samples):
row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))
swanlab.log({"val/generations": swanlab_text_list}, step=step)
GEN_LOGGERS = {
"console": ConsoleGenerationLogger,
"wandb": WandbGenerationLogger,
"swanlab": SwanlabGenerationLogger,
}
@dataclass
class AggregateGenerationsLogger:
def __init__(self, loggers: List[str]):
self.loggers: List[GenerationLogger] = []
for logger in loggers:
if logger in GEN_LOGGERS:
self.loggers.append(GEN_LOGGERS[logger]())
def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
for logger in self.loggers:
logger.log(samples, step)
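# Illustrative usage (not part of the diff): unknown backend names are silently
# skipped, so this logs the samples to the console only.
#
#   >>> gen_logger = AggregateGenerationsLogger(["console", "nonexistent"])
#   >>> gen_logger.log([("1 + 1 = ?", "<think>easy</think>2", 1.0)], step=0)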
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backends
"""
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.utils.tensorboard import SummaryWriter
from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict
from .gen_logger import AggregateGenerationsLogger
if is_package_available("mlflow"):
import mlflow # type: ignore
if is_package_available("wandb"):
import wandb # type: ignore
if is_package_available("swanlab"):
import swanlab # type: ignore
class Logger(ABC):
@abstractmethod
def __init__(self, config: Dict[str, Any]) -> None: ...
@abstractmethod
def log(self, data: Dict[str, Any], step: int) -> None: ...
def finish(self) -> None:
pass
class ConsoleLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
print("Config\n" + convert_dict_to_str(config))
def log(self, data: Dict[str, Any], step: int) -> None:
print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data)))
class MlflowLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
mlflow.start_run(run_name=config["trainer"]["experiment_name"])
mlflow.log_params(flatten_dict(config))
def log(self, data: Dict[str, Any], step: int) -> None:
mlflow.log_metrics(metrics=data, step=step)
class TensorBoardLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log")
os.makedirs(tensorboard_dir, exist_ok=True)
print(f"Saving tensorboard log to {tensorboard_dir}.")
self.writer = SummaryWriter(tensorboard_dir)
self.writer.add_hparams(flatten_dict(config), metric_dict={})
def log(self, data: Dict[str, Any], step: int) -> None:
for key, value in data.items():
self.writer.add_scalar(key, value, step)
def finish(self):
self.writer.close()
class WandbLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
wandb.init(
project=config["trainer"]["project_name"],
name=config["trainer"]["experiment_name"],
config=config,
)
def log(self, data: Dict[str, Any], step: int) -> None:
wandb.log(data=data, step=step)
def finish(self) -> None:
wandb.finish()
class SwanlabLogger(Logger):
def __init__(self, config: Dict[str, Any]) -> None:
swanlab_key = os.getenv("SWANLAB_API_KEY")
swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
swanlab_mode = os.getenv("SWANLAB_MODE", "cloud")
if swanlab_key:
swanlab.login(swanlab_key)
swanlab.init(
project=config["trainer"]["project_name"],
experiment_name=config["trainer"]["experiment_name"],
config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
logdir=swanlab_dir,
mode=swanlab_mode,
)
def log(self, data: Dict[str, Any], step: int) -> None:
swanlab.log(data=data, step=step)
def finish(self) -> None:
swanlab.finish()
LOGGERS = {
"wandb": WandbLogger,
"mlflow": MlflowLogger,
"tensorboard": TensorBoardLogger,
"console": ConsoleLogger,
"swanlab": SwanlabLogger,
}
class Tracker:
def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
if isinstance(loggers, str):
loggers = [loggers]
self.loggers: List[Logger] = []
for logger in loggers:
if logger not in LOGGERS:
raise ValueError(f"{logger} is not supported.")
self.loggers.append(LOGGERS[logger](config))
self.gen_logger = AggregateGenerationsLogger(loggers)
def log(self, data: Dict[str, Any], step: int) -> None:
for logger in self.loggers:
logger.log(data=data, step=step)
def log_generation(self, samples: List[Tuple[str, str, float]], step: int) -> None:
self.gen_logger.log(samples, step)
def __del__(self):
for logger in self.loggers:
logger.finish()
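# A minimal usage sketch (the config keys mirror the ones the loggers above
# read; project/experiment names are placeholders):
#
#   >>> config = {"trainer": {"project_name": "demo", "experiment_name": "run-0"}}
#   >>> tracker = Tracker(loggers="console", config=config)
#   >>> tracker.log({"train/loss": 0.5}, step=1)
#   >>> tracker.log_generation([("prompt", "output", 1.0)], step=1)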
@@ -15,11 +15,28 @@
Utilities to create common models
"""
from functools import lru_cache
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
def get_model_size(model: nn.Module, scale="auto"):
@lru_cache
def is_rank0() -> bool:
return (not dist.is_initialized()) or (dist.get_rank() == 0)
def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None:
"""Report the current GPU VRAM usage."""
if is_rank0():
free_mem, total_mem = torch.cuda.mem_get_info()
print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.")
def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
"""Compute the model size."""
n_params = sum(p.numel() for p in model.parameters())
if scale == "auto":
@@ -41,18 +58,16 @@ def get_model_size(model: nn.Module, scale="auto"):
elif scale == "":
pass
else:
raise NotImplementedError(f"Unknown scale {scale}")
raise NotImplementedError(f"Unknown scale {scale}.")
return n_params, scale
def print_model_size(model: nn.Module, name: str = None):
n_params, scale = get_model_size(model, scale="auto")
if name is None:
name = model.__class__.__name__
print(f"{name} contains {n_params:.2f}{scale} parameters")
def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
"""Print the model size."""
if is_rank0():
n_params, scale = _get_model_size(model, scale="auto")
if name is None:
name = model.__class__.__name__
def compute_position_id_with_mask(mask):
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
print(f"{name} contains {n_params:.2f}{scale} parameters.")
@@ -15,23 +15,89 @@
Contain small python utility functions
"""
from typing import Any, Dict, List
import importlib.util
import re
from functools import lru_cache
from typing import Any, Dict, List, Union
import numpy as np
import yaml
from yaml import Dumper
def is_sci_notation(number: float) -> bool:
pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$")
return bool(pattern.match(str(number)))
def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
if is_sci_notation(number):
value = str(number)
if "." not in value and "e" in value:
value = value.replace("e", ".0e", 1)
else:
value = str(round(number, 3))
return dumper.represent_scalar("tag:yaml.org,2002:float", value)
yaml.add_representer(float, float_representer)
yaml.add_representer(np.float32, float_representer)
yaml.add_representer(np.float64, float_representer)
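# Sanity check of the representer (illustrative; the exact output assumes
# PyYAML's default alphabetical key sorting):
#
#   >>> print(yaml.dump({"lr": 1e-6, "clip_ratio": 0.123456}), end="")
#   clip_ratio: 0.123
#   lr: 1.0e-06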
@lru_cache
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
"""Union two dict. Will throw an error if there is an item not the same object with the same key."""
for key, value in dict2.items():
for key in dict2.keys():
if key in dict1:
assert dict1[key] != value, f"{key} in meta_dict1 and meta_dict2 are not the same object"
assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object"
dict1[key] = value
dict1[key] = dict2[key]
return dict1
def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
"""Append dict to a dict of list."""
for key, val in new_data.items():
if key not in data:
data[key] = []
data[key].append(val)
def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
unflattened = {}
for key, value in data.items():
pieces = key.split(sep)
pointer = unflattened
for piece in pieces[:-1]:
if piece not in pointer:
pointer[piece] = {}
pointer = pointer[piece]
pointer[pieces[-1]] = value
return unflattened
def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
flattened = {}
for key, value in data.items():
new_key = parent_key + sep + key if parent_key else key
if isinstance(value, dict):
flattened.update(flatten_dict(value, new_key, sep=sep))
else:
flattened[new_key] = value
return flattened
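# flatten_dict and unflatten_dict are inverses for nested string-keyed dicts;
# a quick illustrative round trip:
#
#   >>> flat = flatten_dict({"trainer": {"lr": 1e-6, "epochs": 2}})
#   >>> flat
#   {'trainer/lr': 1e-06, 'trainer/epochs': 2}
#   >>> unflatten_dict(flat)
#   {'trainer': {'lr': 1e-06, 'epochs': 2}}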
def convert_dict_to_str(data: Dict[str, Any]) -> str:
return yaml.dump(data, indent=2)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict
from mathruler.grader import extract_boxed_content, grade_answer
def math_compute_score(predict_str: str, ground_truth: str) -> float:
def math_format_reward(predict_str: str) -> float:
pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0
def math_acc_reward(predict_str: str, ground_truth: str) -> float:
answer = extract_boxed_content(predict_str)
if answer == "None":
return 0.0 # no answer
return 1.0 if grade_answer(answer, ground_truth) else 0.0
if grade_answer(answer, ground_truth):
return 1.0 # correct answer
return 0.1 # wrong answer
def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
format = math_format_reward(predict_str)
accuracy = math_acc_reward(predict_str, ground_truth)
return {
"overall": 0.9 * accuracy + 0.1 * format,
"format": format,
"accuracy": accuracy,
}
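# Illustrative scores (assuming mathruler grades "2" against "2" as correct):
#
#   >>> math_compute_score("<think>1 + 1</think> The answer is \\boxed{2}.", "2")
#   {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}
#
# A correct answer without the <think> block would keep accuracy 1.0 but get
# format 0.0, giving overall 0.9.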
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict
from mathruler.grader import grade_answer
def r1v_format_reward(predict_str: str) -> float:
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
match = re.fullmatch(pattern, predict_str, re.DOTALL)
return 1.0 if match else 0.0
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0
def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
try:
ground_truth = ground_truth.strip()
content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
pred_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(pred_answer, ground_truth):
given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(given_answer, ground_truth):
return 1.0
except Exception:
pass
return 0.0
def r1v_compute_score(predict_str: str, ground_truth: str) -> float:
acc_reward = r1v_accuracy_reward(predict_str, ground_truth)
format_reward = r1v_format_reward(predict_str)
reward = acc_reward + format_reward
reward /= 2
return reward
def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
format = r1v_format_reward(predict_str)
accuracy = r1v_accuracy_reward(predict_str, ground_truth)
return {
"overall": 0.5 * accuracy + 0.5 * format,
"format": format,
"accuracy": accuracy,
}
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import heapq
from typing import List, Tuple
import torch
from tensordict import TensorDict
from torch import distributed as dist
class Set:
def __init__(self) -> None:
self.sum = 0
self.items = []
def add(self, idx: int, val: int):
self.items.append((idx, val))
self.sum += val
def merge(self, other):
for idx, val in other.items:
self.items.append((idx, val))
self.sum += val
def __lt__(self, other):
if self.sum != other.sum:
return self.sum < other.sum
if len(self.items) != len(other.items):
return len(self.items) < len(other.items)
return self.items < other.items
class State:
def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
self.k = k
# sets should always be in decreasing order
self.sets = [Set() for _ in range(k)]
assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
for i, (idx, seqlen) in enumerate(items):
self.sets[i].add(idx=idx, val=seqlen)
self.sets = sorted(self.sets, reverse=True)
def get_partitions(self):
partitions = []
for i in range(len(self.sets)):
cur_partition = []
for idx, _ in self.sets[i].items:
cur_partition.append(idx)
partitions.append(cur_partition)
return partitions
def merge(self, other):
for i in range(self.k):
self.sets[i].merge(other.sets[self.k - 1 - i])
self.sets = sorted(self.sets, reverse=True)
@property
def spread(self) -> int:
return self.sets[0].sum - self.sets[-1].sum
def __lt__(self, other):
# min-heap: let the state with the largest spread be popped first;
# if the spreads are equal, let the state with the largest set
# be popped first.
if self.spread != other.spread:
return self.spread > other.spread
return self.sets[0] > other.sets[0]
def __repr__(self) -> str:
repr_str = "["
for i in range(self.k):
if i > 0:
repr_str += ","
repr_str += "{"
for j, (_, seqlen) in enumerate(self.sets[i].items):
if j > 0:
repr_str += ","
repr_str += str(seqlen)
repr_str += "}"
repr_str += "]"
return repr_str
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
# see: https://en.wikipedia.org/wiki/Largest_differencing_method
sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
states_pq: List[State] = []
if equal_size:
assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
for offset in range(0, len(sorted_seqlen_list), k_partitions):
items = []
for i in range(k_partitions):
seqlen, idx = sorted_seqlen_list[offset + i]
items.append((idx, seqlen))
heapq.heappush(states_pq, State(items=items, k=k_partitions))
else:
for seqlen, idx in sorted_seqlen_list:
heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))
while len(states_pq) > 1:
state0 = heapq.heappop(states_pq)
state1 = heapq.heappop(states_pq)
# merge states
state0.merge(state1)
heapq.heappush(states_pq, state0)
final_state = states_pq[0]
partitions = final_state.get_partitions()
if equal_size:
for i, partition in enumerate(partitions):
assert len(partition) * k_partitions == len(seqlen_list), (
f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
)
return partitions
def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
bias = sum(seqlen_list) + 1 if equal_size else 0
sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
partitions = [[] for _ in range(k_partitions)]
partition_sums = [0 for _ in range(k_partitions)]
for seqlen, i in sorted_seqlen:
min_idx = None
for j in range(k_partitions):
if min_idx is None or partition_sums[j] < partition_sums[min_idx]:
min_idx = j
partitions[min_idx].append(i)
partition_sums[min_idx] += seqlen
if equal_size:
for i, partition in enumerate(partitions):
assert len(partition) * k_partitions == len(seqlen_list), (
f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
)
return partitions
def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
"""get order of seq lengths to make partitions balanced, this is
used in balacing sum of seqlength across dp ranks and microbatches
Parameters:
seqlen_list (List[int]):
seq lengths of each items
k_partitions (int):
resulting number of partitions
equal_size (bool):
if True, number of items in each partitions must be equal.
if False, only consider balancing the sum, each partition can have
variable number of items
Returns:
partitions (List[List[int]]):
return k_partitions list containing the index of items.
"""
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
def _check_and_sort_partitions(partitions):
assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
seen_idx = set()
sorted_partitions = [None] * k_partitions
for i, partition in enumerate(partitions):
assert len(partition) > 0, f"the {i}-th partition is empty"
for idx in partition:
seen_idx.add(idx)
sorted_partitions[i] = sorted(partition)
assert seen_idx == set(range(len(seqlen_list)))
return sorted_partitions
partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
return _check_and_sort_partitions(partitions)
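# Illustrative call (the relative order of the partitions is an implementation
# detail, but each inner index list is sorted and the sums come out balanced):
#
#   >>> get_seqlen_balanced_partitions([5, 4, 3, 2], k_partitions=2, equal_size=True)
#   [[1, 2], [0, 3]]   # sums: 4 + 3 = 7 and 5 + 2 = 7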
def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
# add some metrics of seqlen sum on dp ranks
k_partition = len(partitions)
# assert len(seqlen_list) % k_partition == 0
batch_size = len(seqlen_list) // k_partition
min_sum_seqlen = None
max_sum_seqlen = None
total_sum_seqlen = 0
for offset in range(0, len(seqlen_list), batch_size):
cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
min_sum_seqlen = cur_sum_seqlen
if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
max_sum_seqlen = cur_sum_seqlen
total_sum_seqlen += cur_sum_seqlen
balanced_sum_seqlen_list = []
for partition in partitions:
cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)
# print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
return {
f"{prefix}/min": min_sum_seqlen,
f"{prefix}/max": max_sum_seqlen,
f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen,
f"{prefix}/balanced_min": min_sum_seqlen_balanced,
f"{prefix}/balanced_max": max_sum_seqlen_balanced,
f"{prefix}/mean": total_sum_seqlen / len(partitions),
}
def ceildiv(a, b):
return -(a // -b)
def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
"""Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len
and the number of valid tokens in each micro batch is well balanced.
"""
# this is per local micro_bsz
max_seq_len = batch["attention_mask"].shape[-1]
assert max_token_len >= max_seq_len, (
f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
)
seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
total_seqlen = seq_len_effective.sum().item()
num_micro_batches = ceildiv(total_seqlen, max_token_len)
if dist.is_initialized():
num_micro_batches = torch.tensor([num_micro_batches], device="cuda")
dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
num_micro_batches = num_micro_batches.cpu().item()
seq_len_effective = seq_len_effective.tolist()
assert num_micro_batches <= len(seq_len_effective)
micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
micro_batches = []
for partition in micro_bsz_idx:
curr_micro_batch = []
for idx in partition:
curr_micro_batch.append(batch[idx : idx + 1])
curr_micro_batch = torch.cat(curr_micro_batch)
micro_batches.append(curr_micro_batch)
return micro_batches, micro_bsz_idx
def get_reverse_idx(idx_map):
reverse_idx_map = copy.deepcopy(idx_map)
for i, idx in enumerate(idx_map):
reverse_idx_map[idx] = i
return reverse_idx_map
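# e.g. get_reverse_idx([2, 0, 1]) == [1, 2, 0]: entry i of the result tells you
# where index i of the original order ended up.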
@@ -15,38 +15,28 @@
from typing import Optional
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
def get_tokenizer(model_path, correct_pad_token=True, correct_gemma=True, **kwargs) -> PreTrainedTokenizer:
"""Create a huggingface pretrained tokenizer.
Args:
name (str): The name of the tokenizer.
correct_pad_token (bool): Whether to correct the pad token id.
correct_gemma (bool): Whether to correct the gemma tokenizer.
**kwargs: The keyword arguments for the tokenizer.
Returns:
transformers.PreTrainedTokenizer: The pretrained tokenizer.
"""
config = AutoConfig.from_pretrained(model_path)
def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
"""Create a huggingface pretrained tokenizer."""
tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
if correct_gemma and getattr(config, "model_type", None) in ["gemma", "gemma2"]:
# the EOS token in gemma2 is ambiguous, which may worsen RL performance.
if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
# the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance.
# https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
tokenizer.eos_token = "<end_of_turn>"
if correct_pad_token:
if tokenizer.pad_token_id is None:
print("Pad token is None. Set it to eos_token.")
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def get_processor(model_path, **kwargs) -> Optional[ProcessorMixin]:
def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
"""Create a huggingface pretrained processor."""
try:
processor = AutoProcessor.from_pretrained(model_path, **kwargs)
except Exception:
......
@@ -48,7 +48,7 @@ class PrecisionType:
return precision in BFLOAT_LIST
@staticmethod
def to_dtype(precision):
def to_dtype(precision) -> torch.dtype:
if precision in HALF_LIST:
return torch.float16
elif precision in FLOAT_LIST:
@@ -59,12 +59,12 @@ class PrecisionType:
raise RuntimeError(f"unexpected precision: {precision}")
@staticmethod
def to_str(precision):
def to_str(precision: torch.dtype) -> str:
if precision == torch.float16:
return "fp16"
return "float16"
elif precision == torch.float32:
return "fp32"
return "float32"
elif precision == torch.bfloat16:
return "bf16"
return "bfloat16"
else:
raise RuntimeError(f"unexpected precision: {precision}")
@@ -15,15 +15,12 @@
Contain small torch utilities
"""
import math
from typing import List, Literal, Union
from typing import List, Literal, Optional, Tuple, Union
import torch
import torch.distributed
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from transformers import PreTrainedTokenizer
try:
@@ -34,113 +31,85 @@ except ImportError:
FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False
def logprobs_from_logits(logits, labels):
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
"""
if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
batch_dim = logits.shape[:-1]
last_dim = logits.shape[-1]
logits = logits.reshape(-1, last_dim)
labels = labels.reshape(-1)
output = logprobs_from_logits_flash_attn(logits, labels)
output = output.view(*batch_dim)
else:
output = logprobs_from_logits_v2(logits, labels)
return output
@torch.compiler.disable()
def log_probs_from_logits_flash_attn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
output = cross_entropy_loss(logits, labels, inplace_backward=True)
if not isinstance(output, tuple):
raise ValueError(
"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
)
def logprobs_from_logits_flash_attn(logits, labels):
output = cross_entropy_loss(logits, labels)
assert isinstance(output, tuple), (
"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
)
return -output[0]
def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
"""
A memory efficient implementation of logprobs_from_logits
"""
if logits.dtype in [torch.float32, torch.float64]:
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
logprobs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
logprobs_labels = []
for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption
row_logprobs = F.log_softmax(row_logits, dim=-1)
row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
logprobs_labels.append(row_logprobs_labels)
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Compute log probs on the label ids given logits.
logprobs_labels = torch.stack(logprobs_labels)
return logprobs_labels
We may use torch compile to speed up computing.
Args:
logits (torch.Tensor): logits of the model, shape (batch_size, seqlen, vocab_size)
labels (torch.Tensor): labels of the model, shape (batch_size, seqlen)
def clip_by_value(x, tensor_min, tensor_max):
"""
Tensor extension to torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
Returns:
torch.Tensor: log probs of the labels, shape (batch_size, seqlen)
"""
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
return clipped
batch_dim = logits.shape[:-1]
vocab_dim = logits.shape[-1]
logits = logits.contiguous().view(-1, vocab_dim)
labels = labels.contiguous().view(-1)
if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
output = log_probs_from_logits_flash_attn(logits, labels)
else: # fall back to torch kernel, upcast logits to fp32
output = -F.cross_entropy(logits.float(), labels, reduction="none")
def entropy_from_logits(logits: torch.Tensor):
"""Calculate entropy from logits."""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
return entropy
return output.view(*batch_dim)
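# Quick sanity check (illustrative): the result should match gathering from
# log_softmax directly, up to numerical tolerance.
#
#   >>> logits = torch.randn(2, 5, 11)
#   >>> labels = torch.randint(0, 11, (2, 5))
#   >>> ref = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
#   >>> torch.allclose(log_probs_from_logits(logits, labels), ref, atol=1e-5)
#   True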
def masked_mean(values, mask, axis=None) -> torch.Tensor:
def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, eps: float = 1e-8) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + eps)
def masked_var(values, mask, unbiased=True):
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
centered_values = values - mean
variance = masked_mean(centered_values**2, mask)
if unbiased:
mask_sum = mask.sum()
if mask_sum == 0:
raise ValueError("At least one element in the mask has to be 1.")
# note that if mask_sum == 1, then there is a division by zero issue
# to avoid it you just need to use a larger minibatch_size
if mask_sum == 1:
raise ValueError("The sum of the mask is one, which can cause a division by zero.")
if mask_sum <= 1:
print("The sum of the mask is less than one, which can cause a division by zero.")
return variance
bessel_correction = mask_sum / (mask_sum - 1)
variance = variance * bessel_correction
return variance
def masked_whiten(values, mask, shift_mean=True):
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
"""Whiten values with masked values."""
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
return (values - mean) * torch.rsqrt(var + eps)
def get_eos_mask(response_ids: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
"""
end of sentence token can be int or list: 1 or [1, 2]
e.g. eos_token=1
response_ids: [0, 0, 2, 42, 3, 5, 1, 0, 0]
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
def get_eos_mask(response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long):
"""Get the mask for the response ids, the mask will be 0 after the first eos token.
eos_token_id can be int or list: 1 or [1, 2].
```
e.g. eos_token = 1
response_ids: [0, 0, 2, 4, 3, 5, 1, 0, 0]
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
```
"""
if isinstance(eos_token, int):
eos_token = [eos_token]
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_mask = torch.zeros_like(response_ids, dtype=torch.bool)
for token in eos_token:
eos_mask |= response_ids.eq(token)
for token_id in eos_token_id:
eos_mask |= response_ids.eq(token_id)
eos_mask = eos_mask.long()
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
@@ -148,151 +117,211 @@ def get_eos_mask(response_ids: torch.Tensor, eos_token: Union[int, List[int]] =
return eos_mask
def pad_2d_list_to_length(response, pad_token_id, max_length=None) -> torch.Tensor:
"""
pad a 2D list (e.g. responses, logprobs) to a 2D tensor.
"""
response_length = max(len(sub_list) for sub_list in response)
if max_length is not None and max_length > response_length:
def pad_2d_list_to_length(
response: List[List[int]], pad_token_id: int, max_length: Optional[int] = None
) -> torch.Tensor:
"""Pad a 2D list (e.g. responses, log_probs) to a 2D tensor."""
max_response_length = max(len(sub_list) for sub_list in response)
if max_length is not None and max_length > max_response_length:
target_length = max_length
else:
target_length = response_length
target_length = max_response_length
padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]
tensor = torch.tensor(padded_response)
return tensor
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
"""
pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
input shape: [bs, seq_length]
output shape: [bs, max_seq_length]
(0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
"""
if tensors.shape[-1] >= max_seq_len:
return tensors
def pad_sequence_to_length(
tensor: torch.Tensor, max_seq_len: int, pad_token_id: int, left_pad: bool = False
) -> torch.Tensor:
"""Pad a nD tensors in the last dim to max_seq_len."""
if tensor.size(-1) >= max_seq_len:
return tensor
pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])
return F.pad(tensors, pad_tuple, "constant", pad_token_id)
pad_shape = list(tensor.shape)
pad_shape[-1] = max_seq_len - tensor.size(-1)
pad_tensor = torch.full(pad_shape, fill_value=pad_token_id, dtype=tensor.dtype, device=tensor.device)
return torch.cat((pad_tensor, tensor), dim=-1) if left_pad else torch.cat((tensor, pad_tensor), dim=-1)
def tokenize_and_postprocess_data(
prompt: str,
tokenizer: PreTrainedTokenizer,
def postprocess_data(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
max_length: int,
pad_token_id: int,
left_pad: bool = True,
truncation: Literal["left", "right", "error"] = "error",
):
"""
input_data is the output from tokenizer.
"""
"""Pad or truncate data."""
assert truncation in ["left", "right", "error"]
input_data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
input_ids = input_data["input_ids"][0]
attention_mask = input_data["attention_mask"][0]
sequence_length = len(input_ids)
if sequence_length < max_length:
seq_length = len(input_ids)
if seq_length < max_length:
input_ids = pad_sequence_to_length(
input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad
)
attention_mask = pad_sequence_to_length(
attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad
)
elif sequence_length > max_length:
if truncation == "left":
# actually, left truncation may not be reasonable
input_ids = input_ids[-max_length:]
attention_mask = attention_mask[-max_length:]
position_ids = pad_sequence_to_length(position_ids, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad)
elif seq_length > max_length:
if truncation == "left": # actually, left truncation may not be reasonable
input_ids = input_ids[..., -max_length:]
attention_mask = attention_mask[..., -max_length:]
position_ids = position_ids[..., -max_length:]
elif truncation == "right":
input_ids = input_ids[:max_length]
attention_mask = attention_mask[:max_length]
input_ids = input_ids[..., :max_length]
attention_mask = attention_mask[..., :max_length]
position_ids = position_ids[..., :max_length]
elif truncation == "error":
raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}")
raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
else:
raise NotImplementedError(f"Unknown truncation method {truncation}")
return input_ids, attention_mask
raise NotImplementedError(f"Unknown truncation method {truncation}.")
def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
"""Remove the pad token.
Args:
input_ids shape: [bs, seq_length]
attention_mask shape: [bs, seq_length]
Returns:
no_padding_batch (List[List[int]]): the token ids with padding removed, per sequence.
"""
no_padding_batch = []
for ids, mask in zip(input_ids, attention_mask):
no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist())
return no_padding_batch
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(0.0, x * coef + intercept)
return LambdaLR(optimizer, lr_lambda, last_epoch)
return input_ids, attention_mask, position_ids
def get_constant_schedule_with_warmup(
optimizer: Optimizer,
optimizer: torch.optim.Optimizer,
num_warmup_steps: int,
last_epoch: int = -1,
):
def lr_lambda(current_step):
return min(1, float(current_step) / float(max(1, num_warmup_steps)))
) -> torch.optim.lr_scheduler.LRScheduler:
"""Get the lr scheduler for constant lr."""
def lr_lambda(current_step: int) -> float:
return min(1.0, float(current_step) / float(max(1, num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
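# Worked example (made-up mask) of what get_unpad_data returns; cu_seqlens is
# the prefix-sum format that flash-attn varlen kernels expect.
import torch

_mask = torch.tensor([[1, 1, 0, 0],
                      [1, 1, 1, 0]])
_indices, _cu_seqlens, _max_seqlen = get_unpad_data(_mask)
assert _indices.tolist() == [0, 1, 4, 5, 6]  # positions of 1s in the flat mask
assert _cu_seqlens.tolist() == [0, 2, 5]     # per-row lengths [2, 3], prefix-summed
assert _max_seqlen == 3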
# https://github.com/meta-llama/llama-cookbook/blob/v0.0.5/src/llama_cookbook/policies/anyprecision_optimizer.py
class AnyPrecisionAdamW(torch.optim.Optimizer):
def __init__(
self,
params: List[torch.Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
use_kahan_summation: bool = True,
momentum_dtype: torch.dtype = torch.bfloat16,
variance_dtype: torch.dtype = torch.bfloat16,
compensation_buffer_dtype: torch.dtype = torch.bfloat16,
):
"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 0.0)
# Any Precision specific
use_kahan_summation = creates auxiliary buffer to ensure high precision
model param updates (default: True)
momentum_dtype = dtype for momentum (default: bfloat16)
variance_dtype = dtype for uncentered variance (default: bfloat16)
compensation_buffer_dtype = dtype for Kahan summation buffer (default: bfloat16)
# Usage
This optimizer implements optimizer states, and Kahan summation
for high precision updates, all in user controlled dtypes.
Defaults keep momentum, variance, and the compensation buffer in BF16.
This can be run in FSDP mixed precision, amp, or full precision,
depending on what training pipeline you wish to work with.
Setting use_kahan_summation = False and changing the momentum and
variance dtypes to FP32 reverts this to a standard AdamW optimizer.
"""
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"use_kahan_summation": use_kahan_summation,
"momentum_dtype": momentum_dtype,
"variance_dtype": variance_dtype,
"compensation_buffer_dtype": compensation_buffer_dtype,
}
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""
if closure is not None:
with torch.enable_grad():
closure()
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
weight_decay = group["weight_decay"]
eps = group["eps"]
use_kahan_summation = group["use_kahan_summation"]
momentum_dtype = group["momentum_dtype"]
variance_dtype = group["variance_dtype"]
compensation_buffer_dtype = group["compensation_buffer_dtype"]
for p in group["params"]:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0)
# momentum - EMA of gradient values
state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)
# variance uncentered - EMA of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)
# optional Kahan summation - accumulated error tracker
if use_kahan_summation:
state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype)
# Main processing
# update the steps for each param group update
state["step"] += 1
step = state["step"]
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
grad = p.grad
if weight_decay: # weight decay, AdamW style
p.data.mul_(1 - lr * weight_decay)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # update momentum
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # update uncentered variance
bias_correction1 = 1 - beta1**step # adjust using bias1
step_size = lr / bias_correction1
denom_correction = (1 - beta2**step) ** 0.5 # adjust using bias2; avoids a math import
centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1) # bias-corrected denominator (name kept from upstream)
if use_kahan_summation: # lr update to compensation
compensation = state["compensation"]
compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
# update weights with compensation (Kahan summation)
# save error back to compensation for next iteration
temp_buffer = p.detach().clone()
p.data.add_(compensation)
compensation.add_(temp_buffer.sub_(p.data))
else: # usual AdamW updates
p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
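# A hedged usage sketch (toy model, made-up shapes): with the defaults above,
# momentum/variance live in bf16 and the Kahan buffer compensates rounding in
# the bf16 weight updates; use_kahan_summation=False with fp32 state dtypes
# recovers plain AdamW.
import torch

_model = torch.nn.Linear(16, 16).bfloat16()
_optim = AnyPrecisionAdamW(_model.parameters(), lr=1e-3)
_loss = _model(torch.randn(2, 16, dtype=torch.bfloat16)).sum()
_loss.backward()
_optim.step()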
......@@ -238,7 +238,7 @@ class Gather(torch.autograd.Function):
)
def gather_outpus_and_unpad(
def gather_outputs_and_unpad(
x: Tensor,
gather_dim: int,
unpad_dim: int = None,
......
......@@ -20,8 +20,8 @@ from typing import Any, Dict
import torch
from verl import DataProto
from verl.workers.actor.config import ActorConfig
from ...protocol import DataProto
from .config import ActorConfig
__all__ = ["BasePPOActor"]
......
......@@ -26,6 +26,7 @@ class ModelConfig:
override_config: Dict[str, Any] = field(default_factory=dict)
enable_gradient_checkpointing: bool = True
trust_remote_code: bool = True
freeze_vision_tower: bool = False
def post_init(self):
if self.tokenizer_path is None:
......@@ -37,7 +38,8 @@ class OptimConfig:
lr: float = 1e-6
betas: Tuple[float, float] = (0.9, 0.999)
weight_decay: float = 1e-2
lr_warmup_steps_ratio: float = 0.0
strategy: str = "adamw"
lr_warmup_ratio: float = 0.0
min_lr_ratio: Optional[float] = None
warmup_style: str = "constant"
"""auto keys"""
......@@ -47,9 +49,11 @@ class OptimConfig:
@dataclass
class FSDPConfig:
enable_full_shard: bool = True
param_offload: bool = False
optimizer_offload: bool = False
enable_cpu_offload: bool = False
enable_rank0_init: bool = False
use_orig_params: bool = False
torch_dtype: Optional[str] = None
fsdp_size: int = -1
mp_param_dtype: str = "bf16"
mp_reduce_dtype: str = "fp32"
mp_buffer_dtype: str = "fp32"
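# A hedged sketch (the helper name and dtype table are assumptions, not part
# of this commit): how the mp_*_dtype strings above could map onto torch FSDP
# mixed precision.
import torch
from torch.distributed.fsdp import MixedPrecision

_DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

def build_mixed_precision(cfg: "FSDPConfig") -> MixedPrecision:
    return MixedPrecision(
        param_dtype=_DTYPES[cfg.mp_param_dtype],    # compute dtype of params
        reduce_dtype=_DTYPES[cfg.mp_reduce_dtype],  # dtype for grad reduction
        buffer_dtype=_DTYPES[cfg.mp_buffer_dtype],  # dtype of module buffers
    )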
......@@ -57,41 +61,41 @@ class FSDPConfig:
@dataclass
class OffloadConfig:
param_offload: bool = False
optimizer_offload: bool = False
offload_params: bool = False
offload_optimizer: bool = False
@dataclass
class ActorConfig:
strategy: str = "fsdp"
global_batch_size: int = 256
micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
micro_batch_size_per_device_for_update: int = 4
micro_batch_size_per_device_for_experience: int = 16
max_grad_norm: float = 1.0
clip_ratio: float = 0.2
entropy_coeff: float = 1e-3
use_kl_loss: bool = True
kl_loss_coef: float = 1e-3
kl_loss_type: str = "low_var_kl"
ppo_epochs: int = 1
padding_free: bool = False
ulysses_sequence_parallel_size: int = 1
use_torch_compile: bool = True
model: ModelConfig = field(default_factory=ModelConfig)
optim: OptimConfig = field(default_factory=OptimConfig)
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
offload: OffloadConfig = field(default_factory=OffloadConfig)
"""auto keys"""
global_batch_size_per_device: int = field(default=-1, init=False)
def post_init(self):
if self.ppo_epochs != 1:
raise NotImplementedError
disable_kl: bool = field(default=False, init=False)
use_kl_loss: bool = field(default=False, init=False)
kl_penalty: str = field(default="kl", init=False)
kl_coef: float = field(default=0.0, init=False)
@dataclass
class RefConfig:
strategy: str = "fsdp"
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
offload: OffloadConfig = field(default_factory=OffloadConfig)
"""auto keys"""
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
padding_free: bool = field(default=False, init=False)
ulysses_sequence_parallel_size: int = field(default=1, init=False)
use_torch_compile: bool = field(default=True, init=False)
......@@ -17,20 +17,26 @@ Implement Actor
import os
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import torch
from ray.experimental.tqdm_ray import tqdm
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from tqdm import tqdm
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer import core_algos
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, masked_mean
from verl.workers.actor.base import BasePPOActor
from verl.workers.actor.config import ActorConfig
from ...protocol import DataProto
from ...trainer import core_algos
from ...utils import torch_functional as VF
from ...utils.py_functional import append_to_dict
from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from .base import BasePPOActor
from .config import ActorConfig
try:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
except ImportError:
pass  # flash_attn is optional; only required for the padding_free path
__all__ = ["DataParallelPPOActor"]
......@@ -50,17 +56,18 @@ class DataParallelPPOActor(BasePPOActor):
self.rank = int(os.getenv("RANK", "0"))
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
if config.use_torch_compile:
self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True)
else:
self.log_probs_from_logits = VF.log_probs_from_logits
def _forward_micro_batch(
self, micro_batch: Dict[str, torch.Tensor], temperature: float
) -> Tuple[torch.Tensor, torch.Tensor]:
def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature: float) -> torch.Tensor:
"""
Returns:
entropy: # (bs, response_len)
log_probs: # (bs, response_len)
"""
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
responses = micro_batch["responses"]
......@@ -68,29 +75,82 @@ class DataParallelPPOActor(BasePPOActor):
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
vision_inputs = {}
if "pixel_values" in micro_batch:
vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch:
for key in micro_batch["multi_modal_inputs"][0].keys():
multi_modal_inputs[key] = torch.cat(
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
)
if self.config.padding_free:
# TODO (yaowei): preprocess data for padding_free and ulysses
raise NotImplementedError
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# roll input_ids left by one to form next-token labels for the log-probs
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
# pad and slice the inputs if sp > 1
if self.config.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled, None, self.config.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.actor_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
**multi_modal_inputs,
use_cache=False,
) # prevent the model from thinking we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
logits_rmpad.div_(temperature)
# ((total_nnz / sp) + pad)
log_probs = self.log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
# gather log_prob if sp > 1
if self.config.ulysses_sequence_parallel_size > 1:
# gather and unpad for the ulysses sp
log_probs = gather_outputs_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# pad back to (bsz, seqlen)
full_log_probs = pad_input(
hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
)
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
else:
output = self.actor_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**vision_inputs,
**multi_modal_inputs,
use_cache=False,
)
logits: torch.Tensor = output.logits
logits.div_(temperature)
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
log_probs = logprobs_from_logits(logits, responses) # (bsz, response_length)
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
log_probs = self.log_probs_from_logits(logits, responses) # (bsz, response_length)
return entropy, log_probs
return log_probs
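# A hedged sketch of the unpad/pad round-trip used in the padding_free branch
# above (flash_attn.bert_padding; toy values, no model call): unpad_input
# drops padded positions, pad_input scatters them back to (bsz, seqlen).
import torch
from flash_attn.bert_padding import pad_input, unpad_input

_hidden = torch.randn(2, 4, 8)                      # (bsz, seqlen, dim)
_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
_packed, _indices, _cu_seqlens, _max_len, *_ = unpad_input(_hidden, _mask)
assert _packed.shape == (5, 8)                      # total_nnz tokens kept
_restored = pad_input(_packed, _indices, batch=2, seqlen=4)
assert _restored.shape == (2, 4, 8)                 # padding rows are zeros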
def _optimizer_step(self) -> torch.Tensor:
if isinstance(self.actor_module, FSDP):
......@@ -98,7 +158,12 @@ class DataParallelPPOActor(BasePPOActor):
else:
grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm)
self.actor_optimizer.step()
if not torch.isfinite(grad_norm):
print("Gradient norm is not finite. Skip update.")
else:
self.actor_optimizer.step()
self.actor_optimizer.zero_grad()
return grad_norm
@torch.no_grad()
......@@ -124,19 +189,21 @@ class DataParallelPPOActor(BasePPOActor):
temperature = data.meta_info["temperature"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
if "multi_modal_inputs" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["multi_modal_inputs"]
else:
non_tensor_select_keys = None
non_tensor_select_keys = []
micro_batches = data.select(select_keys, non_tensor_select_keys).split(
self.config.micro_batch_size_per_device_for_experience
)
log_probs_lst = []
for micro_batch in tqdm(micro_batches, desc="Compute log probs", disable=(self.rank != 0)):
micro_batch.to("cuda")
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Compute log probs", position=2)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
_, log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
log_probs_lst.append(log_probs)
log_probs = torch.concat(log_probs_lst, dim=0)
......@@ -147,83 +214,74 @@ class DataParallelPPOActor(BasePPOActor):
temperature = data.meta_info["temperature"] # temperature must be in data.meta_info to avoid a silent error
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.use_kl_loss and not self.config.disable_kl:
select_keys.append("ref_log_probs")
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
if "multi_modal_inputs" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["multi_modal_inputs"]
else:
non_tensor_select_keys = None
non_tensor_select_keys = []
# TODO (yaowei): support ppo epochs
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
metrics = defaultdict(list)
n = len(mini_batches)
for i, mini_batch in enumerate(mini_batches):
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
self.actor_optimizer.zero_grad()
for micro_batch in tqdm(micro_batches, desc=f"Update policy [{i + 1}/{n}]", disable=(self.rank != 0)):
micro_batch.to("cuda")
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
response_length = responses.size(1)
attention_mask = model_inputs["attention_mask"]
response_mask = attention_mask[:, -response_length:]
old_log_prob = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]
clip_ratio = self.config.clip_ratio
entropy_coeff = self.config.entropy_coeff
# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(model_inputs, temperature=temperature)
pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio,
for _ in range(self.config.ppo_epochs):
if self.rank == 0:
mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
for mini_batch in mini_batches:
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
# compute entropy loss from entropy
entropy_loss = verl_F.masked_mean(entropy, response_mask)
# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff
if self.config.use_kl_loss:
ref_log_prob = model_inputs["ref_log_prob"]
# compute kl loss
kld = core_algos.kl_penalty(
logprob=log_prob,
ref_logprob=ref_log_prob,
kl_penalty=self.config.kl_loss_type,
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Update policy", position=3)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
response_length = responses.size(1)
attention_mask = model_inputs["attention_mask"]
response_mask = attention_mask[:, -response_length:]
old_log_probs = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]
# all return: (bsz, response_length)
log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
old_log_probs=old_log_probs,
log_probs=log_probs,
advantages=advantages,
eos_mask=response_mask,
cliprange=self.config.clip_ratio,
)
kl_loss = masked_mean(kld, response_mask)
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_loss_coef
loss = policy_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"actor/entropy_loss": entropy_loss.detach().item(),
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
}
append_to_dict(metrics, batch_metrics)
if "ref_log_probs" in model_inputs:
ref_log_probs = model_inputs["ref_log_probs"]
# compute kl loss
kld = core_algos.kl_penalty(
log_probs=log_probs,
ref_log_probs=ref_log_probs,
kl_penalty=self.config.kl_penalty,
)
kl_loss = VF.masked_mean(kld, response_mask)
pg_loss = pg_loss + kl_loss * self.config.kl_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_coef
loss = pg_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
}
append_to_dict(metrics, batch_metrics)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
self.actor_optimizer.zero_grad()
return metrics
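# Worked example of the loss scaling above (illustrative numbers): with
# global_batch_size_per_device = 16 and micro_batch_size_per_device_for_update
# = 4, each mini-batch yields 16 // 4 = 4 micro-batches, each micro-batch loss
# is divided by gradient_accumulation = 4, and the single optimizer step after
# the inner loop matches one full mini-batch update in expectation.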
......@@ -17,10 +17,10 @@ ActorRolloutRef config
from dataclasses import dataclass, field
from verl.workers.actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
from verl.workers.critic import CriticConfig
from verl.workers.reward import RewardConfig
from verl.workers.rollout import RolloutConfig
from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
from .critic import CriticConfig
from .reward import RewardConfig
from .rollout import RolloutConfig
__all__ = [
......@@ -46,5 +46,7 @@ class WorkerConfig:
rollout: RolloutConfig = field(default_factory=RolloutConfig)
def post_init(self):
self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience
self.ref.padding_free = self.actor.padding_free
self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size
self.ref.use_torch_compile = self.actor.use_torch_compile
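# A hedged illustration (assumes WorkerConfig is constructible with defaults;
# field values below are made up): after post_init, the ref worker mirrors the
# actor's batching and parallelism settings, so reference log-prob computation
# is sharded the same way the actor is.
config = WorkerConfig()
config.actor.padding_free = True
config.actor.ulysses_sequence_parallel_size = 2
config.post_init()
assert config.ref.padding_free
assert config.ref.ulysses_sequence_parallel_size == 2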