Commit 2369eb2b authored by chenych

update

parent ac9d2b05
......@@ -98,7 +98,7 @@ class RayResourcePool(ResourcePool):
# print(f"pg_name_prefix = {pg_name_prefix}")
pg_scheme = [
[
{"CPU": self.max_collocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_collocate_count}
{"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count}
for _ in range(process_count)
]
for process_count in self._store
......@@ -145,8 +145,8 @@ def extract_pg_from_exist(
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not"
assert rp1.max_collocate_count == rp2.max_collocate_count, (
"Both RayResourcePool must has the same max_collocate_count"
assert rp1.max_colocate_count == rp2.max_colocate_count, (
"Both RayResourcePool must have the same max_colocate_count"
)
assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must have the same n_gpus_per_node"
assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool"
......@@ -259,7 +259,7 @@ class RayWorkerGroup(WorkerGroup):
world_size = resource_pool.world_size
self._world_size = world_size
# cia.add_kwarg("_world_size", world_size)
num_gpus = 1 / resource_pool.max_collocate_count
num_gpus = 1 / resource_pool.max_colocate_count
rank = -1
local_world_size = resource_pool.store[0]
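The fractional `num_gpus = 1 / resource_pool.max_colocate_count` relies on Ray's fractional GPU scheduling, so several colocated workers can share one physical device. A minimal sketch of that mechanism, assuming a colocate count of 2 (the actor class below is illustrative, not from this repository):

import ray

# With max_colocate_count = 2, each actor requests half a GPU, so two
# colocated workers are packed onto the same device.
@ray.remote(num_gpus=0.5)
class ColocatedWorker:
    def ping(self) -> str:
        return "ok"

if __name__ == "__main__":
    ray.init(num_gpus=1)
    workers = [ColocatedWorker.remote() for _ in range(2)]  # both land on the single GPU
    print(ray.get([w.ping.remote() for w in workers]))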
......@@ -300,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
if rank == 0:
register_center_actor = None
for _ in range(360):
for _ in range(120):
if f"{self.name_prefix}_register_center" not in list_named_actors():
time.sleep(1)
else:
......
......@@ -47,6 +47,14 @@ class DataConfig:
seed: int = 1
max_pixels: int = 4194304
min_pixels: int = 262144
filter_overlong_prompts: bool = True
def post_init(self):
if self.format_prompt is not None:
if os.path.exists(self.format_prompt):
self.format_prompt = os.path.abspath(self.format_prompt)
else:
self.format_prompt = None
@dataclass
......@@ -86,6 +94,10 @@ class TrainerConfig:
if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
if self.load_checkpoint_path is not None:
self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
@dataclass
class PPOConfig:
......@@ -97,6 +109,7 @@ class PPOConfig:
def post_init(self):
self.worker.rollout.prompt_length = self.data.max_prompt_length
self.worker.rollout.response_length = self.data.max_response_length
self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code
self.worker.actor.disable_kl = self.algorithm.disable_kl
self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
self.worker.actor.kl_penalty = self.algorithm.kl_penalty
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..utils.dataset import RLHFDataset, collate_fn
from .config import DataConfig
def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]) -> None:
train_dataset = RLHFDataset(
data_path=config.train_files,
tokenizer=tokenizer,
processor=processor,
prompt_key=config.prompt_key,
answer_key=config.answer_key,
image_key=config.image_key,
max_prompt_length=config.max_prompt_length,
truncation="right",
format_prompt=config.format_prompt,
min_pixels=config.min_pixels,
max_pixels=config.max_pixels,
filter_overlong_prompts=config.filter_overlong_prompts,
)
# use sampler for better ckpt resume
if config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(config.seed)
sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=train_dataset)
train_dataloader = StatefulDataLoader(
dataset=train_dataset,
batch_size=config.rollout_batch_size,
sampler=sampler,
num_workers=8,
collate_fn=collate_fn,
pin_memory=False,
drop_last=True,
)
val_dataset = RLHFDataset(
data_path=config.val_files,
tokenizer=tokenizer,
processor=processor,
prompt_key=config.prompt_key,
answer_key=config.answer_key,
image_key=config.image_key,
max_prompt_length=config.max_prompt_length,
truncation="right",
format_prompt=config.format_prompt,
min_pixels=config.min_pixels,
max_pixels=config.max_pixels,
filter_overlong_prompts=config.filter_overlong_prompts,
)
val_dataloader = StatefulDataLoader(
dataset=val_dataset,
batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size,
shuffle=False,
num_workers=8,
collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
)
assert len(train_dataloader) >= 1
assert len(val_dataloader) >= 1
print(f"Size of train dataloader: {len(train_dataloader)}")
print(f"Size of val dataloader: {len(val_dataloader)}")
return train_dataloader, val_dataloader
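The sampler plus StatefulDataLoader combination is what the "use sampler for better ckpt resume" comment refers to: the loader's state_dict() records the sampler position, so training can resume mid-epoch from the same batch. A minimal sketch of that resume flow with a placeholder dataset (not the RLHFDataset used above):

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

# Placeholder dataset; the real code wraps RLHFDataset with a seeded sampler.
dataset = TensorDataset(torch.arange(16))
loader = StatefulDataLoader(dataset, batch_size=4, shuffle=False)

it = iter(loader)
first_batch = next(it)              # consume one batch
state = loader.state_dict()         # records that one batch was already served

resumed = StatefulDataLoader(dataset, batch_size=4, shuffle=False)
resumed.load_state_dict(state)      # continue from the saved position
second_batch = next(iter(resumed))  # yields batch #2, not batch #1 again
print(first_batch, second_batch)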
......@@ -11,21 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
import json
import torch
import ray
from omegaconf import OmegaConf
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import CustomRewardManager
from ..workers.reward import FunctionRewardManager
from .config import PPOConfig
from .data_loader import create_dataloader
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
......@@ -36,7 +33,6 @@ class Runner:
def run(self, config: PPOConfig):
# print config
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer
......@@ -69,13 +65,19 @@ class Runner:
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
train_dataloader, val_dataloader = create_dataloader(
config=config.data, tokenizer=tokenizer, processor=processor
)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
......@@ -96,17 +98,26 @@ def main():
default_config = OmegaConf.merge(default_config, file_config)
ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
ppo_config: PPOConfig = OmegaConf.to_object(ppo_config)
ppo_config.deep_post_init()
if not ray.is_initialized():
runtime_env = {
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
}
}
# this is for local ray cluster
if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ignore_reinit_error=True,
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
runtime_env=runtime_env)
else:
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.init(runtime_env=runtime_env)
runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
......
......@@ -110,11 +110,11 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]:
total_num_tokens = sum(batch.meta_info["global_token_num"])
time = timing_raw["step"]
return {
"perf/total_num_tokens": total_num_tokens,
"perf/time_per_step": time,
"perf/throughput": total_num_tokens / (time * n_gpus),
"perf/throughput": total_num_tokens / (time * num_gpus),
}
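For example, a step that processes 1,000,000 tokens in 20 seconds on 8 GPUs reports perf/total_num_tokens = 1,000,000, perf/time_per_step = 20, and perf/throughput = 1,000,000 / (20 * 8) = 6,250 tokens per second per GPU.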
......@@ -30,7 +30,6 @@ import ray
import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
......@@ -40,7 +39,6 @@ from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWo
from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.dataset import RLHFDataset, collate_fn
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
......@@ -102,24 +100,16 @@ class ResourcePoolManager:
"""Get the resource pool of the worker."""
return self.resource_pool_dict[self.mapping[role]]
def get_n_gpus(self) -> int:
def get_num_gpus(self) -> int:
"""Get the number of gpus in this cluster."""
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
def _check_resource_available(self):
"""Check if the resource pool can be satisfied in this ray cluster."""
node_available_resources = ray.state.available_resources_per_node()
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
# check total required gpus can be satisfied
total_available_gpus = sum(node_available_gpus.values())
total_required_gpus = sum(
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
)
if total_available_gpus < total_required_gpus:
raise ValueError(
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
)
gpus_available = ray.available_resources().get("GPU", 0)
gpus_required = self.get_num_gpus()
if gpus_available < gpus_required:
raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
......@@ -128,11 +118,8 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penal
response_mask = data.batch["response_mask"]
# compute kl between ref_policy and current policy
if "ref_log_probs" in data.batch.keys():
kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
kld = kld * response_mask # (batch_size, response_length)
else:
kld = torch.zeros_like(response_mask, dtype=torch.float32)
data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
......@@ -193,6 +180,8 @@ class RayPPOTrainer:
config: PPOConfig,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
train_dataloader: StatefulDataLoader,
val_dataloader: StatefulDataLoader,
role_worker_mapping: dict[Role, Type[Worker]],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
......@@ -201,6 +190,8 @@ class RayPPOTrainer:
):
self.tokenizer = tokenizer
self.processor = processor
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
......@@ -262,78 +253,13 @@ class RayPPOTrainer:
):
raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.")
self._create_dataloader()
def _create_dataloader(self) -> None:
self.train_dataset = RLHFDataset(
data_path=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
answer_key=self.config.data.answer_key,
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
format_prompt=self.config.data.format_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.seed)
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.rollout_batch_size,
sampler=sampler,
num_workers=8,
collate_fn=collate_fn,
pin_memory=False,
drop_last=True,
)
self.val_dataset = RLHFDataset(
data_path=self.config.data.val_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
answer_key=self.config.data.answer_key,
image_key=self.config.data.image_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
format_prompt=self.config.data.format_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset)
if self.config.data.val_batch_size == -1
else self.config.data.val_batch_size,
shuffle=False,
num_workers=8,
collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
if self.config.trainer.max_steps is not None:
training_steps = self.config.trainer.max_steps
if config.trainer.max_steps is not None:
self.training_steps = config.trainer.max_steps
else:
training_steps = len(self.train_dataloader) * self.config.trainer.total_episodes
self.training_steps = len(train_dataloader) * config.trainer.total_episodes
self.training_steps = training_steps
self.config.worker.actor.optim.training_steps = training_steps
self.config.worker.critic.optim.training_steps = training_steps
config.worker.actor.optim.training_steps = self.training_steps
config.worker.critic.optim.training_steps = self.training_steps
print(f"Total training steps: {self.training_steps}")
def _maybe_log_val_generations(
......@@ -366,10 +292,10 @@ class RayPPOTrainer:
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
if "multi_modal_data" in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
)
else:
test_gen_batch = test_batch.pop(
......@@ -567,10 +493,10 @@ class RayPPOTrainer:
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
if "multi_modal_inputs" in batch.non_tensor_batch.keys():
if "multi_modal_data" in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
)
else:
gen_batch = batch.pop(
......@@ -604,6 +530,7 @@ class RayPPOTrainer:
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
batch.non_tensor_batch.pop("multi_modal_data", None)
# compute reward
with _timer("reward", timing_raw):
......@@ -694,10 +621,10 @@ class RayPPOTrainer:
self._save_checkpoint()
# collect metrics
n_gpus = self.resource_pool_manager.get_n_gpus()
num_gpus = self.resource_pool_manager.get_num_gpus()
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))
self.logger.log(data=metrics, step=self.global_step)
......
......@@ -13,13 +13,12 @@
# limitations under the License.
import os
import warnings
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from .checkpoint_manager import BaseCheckpointManager
......@@ -59,21 +58,18 @@ class FSDPCheckpointManager(BaseCheckpointManager):
extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
model_state_dict = torch.load(model_path, weights_only=False)
optimizer_state_dict = torch.load(optim_path, weights_only=False)
optim_state_dict = torch.load(optim_path, weights_only=False)
extra_state_dict = torch.load(extra_state_path, weights_only=False)
lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
state_dict_options = StateDictOptions(cpu_offload=True)
set_state_dict(
model=self.model,
optimizers=self.optimizer,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
options=state_dict_options,
)
self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"])
# recover random state
if "rng" in extra_state_dict:
......@@ -84,24 +80,10 @@ class FSDPCheckpointManager(BaseCheckpointManager):
dist.barrier()
# every rank will save its own model and optim shard
state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
model_state_dict = self.model.state_dict()
if self.optimizer is not None:
optimizer_state_dict = self.optimizer.state_dict()
else:
optimizer_state_dict = None
if self.lr_scheduler is not None:
lr_scheduler_state_dict = self.lr_scheduler.state_dict()
else:
lr_scheduler_state_dict = None
state_dict_options = StateDictOptions(cpu_offload=True)
model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
extra_state_dict = {
"lr_scheduler": lr_scheduler_state_dict,
"lr_scheduler": self.lr_scheduler.state_dict(),
"rng": self.get_rng_state(),
}
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
......@@ -112,9 +94,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
torch.save(model_state_dict, model_path)
if self.optimizer is not None:
torch.save(optimizer_state_dict, optim_path)
torch.save(optim_state_dict, optim_path)
torch.save(extra_state_dict, extra_path)
# wait for everyone to dump to local
......
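The rewritten manager swaps the FSDP.state_dict_type context manager for the torch.distributed.checkpoint.state_dict helpers, which handle model and optimizer state in one call. A minimal single-process sketch of that API (the small module, optimizer, and file names are placeholders; the real manager applies it to the FSDP-wrapped module with per-rank shard paths):

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict

# Placeholder module/optimizer standing in for the FSDP-wrapped model.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optimizer.step()  # populate optimizer state before saving

options = StateDictOptions(cpu_offload=True)

# save: collect model and optimizer state (sharded per rank under FSDP)
model_state, optim_state = get_state_dict(model, optimizer, options=options)
torch.save(model_state, "model_rank0.pt")   # placeholder paths
torch.save(optim_state, "optim_rank0.pt")

# load: push the saved state back into the live module and optimizer
set_state_dict(
    model,
    optimizer,
    model_state_dict=torch.load("model_rank0.pt", weights_only=False),
    optim_state_dict=torch.load("optim_rank0.pt", weights_only=False),
    options=options,
)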
......@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from datasets import load_dataset
from jinja2 import Template
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
......@@ -90,9 +91,10 @@ class RLHFDataset(Dataset, ImageProcessMixin):
image_key: str = "images",
max_prompt_length: int = 1024,
truncation: str = "error",
format_prompt: str = None,
max_pixels: int = None,
min_pixels: int = None,
format_prompt: Optional[str] = None,
max_pixels: Optional[int] = None,
min_pixels: Optional[int] = None,
filter_overlong_prompts: bool = True,
):
self.tokenizer = tokenizer
self.processor = processor
......@@ -101,9 +103,9 @@ class RLHFDataset(Dataset, ImageProcessMixin):
self.image_key = image_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.format_prompt = format_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
self.filter_overlong_prompts = filter_overlong_prompts
if "@" in data_path:
data_path, data_split = data_path.split("@")
......@@ -111,22 +113,29 @@ class RLHFDataset(Dataset, ImageProcessMixin):
data_split = "train"
if os.path.isdir(data_path):
# when we use dataset builder, we should always refer to the train split
self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
elif os.path.isfile(data_path):
self.dataset = load_dataset("parquet", data_files=data_path, split="train")
else: # remote dataset
else:
# load remote dataset from huggingface hub
self.dataset = load_dataset(data_path, split=data_split)
def __len__(self):
return len(self.dataset)
self.format_prompt = None
if format_prompt:
with open(format_prompt, encoding="utf-8") as f:
self.format_prompt = f.read()
def __getitem__(self, index):
row_dict: dict = self.dataset[index]
prompt_str: str = row_dict[self.prompt_key]
if self.filter_overlong_prompts:
self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")
def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
prompt_str: str = example[self.prompt_key]
if self.format_prompt:
prompt_str = prompt_str + " " + self.format_prompt.strip()
format_prompt = Template(self.format_prompt.strip())
prompt_str = format_prompt.render(content=prompt_str)
if self.image_key in row_dict:
if self.image_key in example:
# https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
content_list = []
for i, content in enumerate(prompt_str.split("<image>")):
......@@ -136,28 +145,47 @@ class RLHFDataset(Dataset, ImageProcessMixin):
if content:
content_list.append({"type": "text", "text": content})
messages = [{"role": "user", "content": content_list}]
return [{"role": "user", "content": content_list}]
else:
return [{"role": "user", "content": prompt_str}]
def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
messages = self._build_messages(example)
processing_class = self.processor if self.processor is not None else self.tokenizer
return (
len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
example: dict = self.dataset[index]
messages = self._build_messages(example)
if self.image_key in example:
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
images = [self.process_image(image) for image in example.pop(self.image_key)]
model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
row_dict["multi_modal_data"] = {"image": images}
row_dict["multi_modal_inputs"] = dict(model_inputs)
example["multi_modal_data"] = {"image": images}
example["multi_modal_inputs"] = dict(model_inputs)
else:
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
# qwen2vl mrope
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=model_inputs["image_grid_thw"],
image_grid_thw=model_inputs.get("image_grid_thw"),
attention_mask=attention_mask,
) # (3, seq_length)
else:
messages = [{"role": "user", "content": prompt_str}]
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
input_ids = model_inputs.pop("input_ids")[0]
attention_mask = model_inputs.pop("attention_mask")[0]
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,)
input_ids, attention_mask, position_ids = VF.postprocess_data(
......@@ -169,9 +197,18 @@ class RLHFDataset(Dataset, ImageProcessMixin):
left_pad=True,
truncation=self.truncation,
)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
row_dict["ground_truth"] = row_dict.pop(self.answer_key)
return row_dict
raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
if len(raw_prompt_ids) > self.max_prompt_length:
if self.truncation == "left":
raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
elif self.truncation == "right":
raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
elif self.truncation == "error":
raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
example["input_ids"] = input_ids
example["attention_mask"] = attention_mask
example["position_ids"] = position_ids
example["raw_prompt_ids"] = raw_prompt_ids
example["ground_truth"] = example.pop(self.answer_key)
return example
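The format prompt is now a Jinja2 template file read from disk and rendered with the raw question bound to `content`, instead of a plain suffix string. A small sketch of that rendering step (the template text below is illustrative, not the file shipped with the repository):

from jinja2 import Template

# Illustrative template text; in practice it comes from the file configured as
# `data.format_prompt` and is read in RLHFDataset.__init__.
template_text = "{{ content }} Please reason step by step, and put your final answer within \\boxed{}."

prompt_str = "What is 2 + 3?"
rendered = Template(template_text.strip()).render(content=prompt_str)
print(rendered)
# -> What is 2 + 3? Please reason step by step, and put your final answer within \boxed{}.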
......@@ -71,7 +71,7 @@ class TensorBoardLogger(Logger):
os.makedirs(tensorboard_dir, exist_ok=True)
print(f"Saving tensorboard log to {tensorboard_dir}.")
self.writer = SummaryWriter(tensorboard_dir)
self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={})
self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={"placeholder": 0})
def log(self, data: Dict[str, Any], step: int) -> None:
for key, value in data.items():
......
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .math import math_compute_score
from .r1v import r1v_compute_score
__all__ = ["math_compute_score", "r1v_compute_score"]
......@@ -15,40 +15,28 @@
import torch
HALF_LIST = [16, "16", "fp16", "float16"]
FLOAT_LIST = [32, "32", "fp32", "float32"]
HALF_LIST = ["fp16", "float16"]
FLOAT_LIST = ["fp32", "float32"]
BFLOAT_LIST = ["bf16", "bfloat16"]
class PrecisionType:
"""Type of precision used.
>>> PrecisionType.HALF == 16
True
>>> PrecisionType.HALF in (16, "16")
True
"""
HALF = "16"
FLOAT = "32"
FULL = "64"
BFLOAT = "bf16"
MIXED = "mixed"
"""Type of precision used."""
@staticmethod
def is_fp16(precision):
def is_fp16(precision: str) -> bool:
return precision in HALF_LIST
@staticmethod
def is_fp32(precision):
def is_fp32(precision: str) -> bool:
return precision in FLOAT_LIST
@staticmethod
def is_bf16(precision):
def is_bf16(precision: str) -> bool:
return precision in BFLOAT_LIST
@staticmethod
def to_dtype(precision) -> torch.dtype:
def to_dtype(precision: str) -> torch.dtype:
if precision in HALF_LIST:
return torch.float16
elif precision in FLOAT_LIST:
......@@ -56,7 +44,7 @@ class PrecisionType:
elif precision in BFLOAT_LIST:
return torch.bfloat16
else:
raise RuntimeError(f"unexpected precision: {precision}")
raise RuntimeError(f"Unexpected precision: {precision}")
@staticmethod
def to_str(precision: torch.dtype) -> str:
......@@ -67,4 +55,4 @@ class PrecisionType:
elif precision == torch.bfloat16:
return "bfloat16"
else:
raise RuntimeError(f"unexpected precision: {precision}")
raise RuntimeError(f"Unexpected precision: {precision}")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright Meta Platforms, Inc. and affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,6 +23,8 @@ import torch.distributed
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from .torch_dtypes import PrecisionType
try:
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
......@@ -177,7 +180,7 @@ def postprocess_data(
attention_mask = attention_mask[..., :max_length]
position_ids = position_ids[..., :max_length]
elif truncation == "error":
raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
raise RuntimeError(f"Input sequence length {seq_length} is longer than max length {max_length}.")
else:
raise NotImplementedError(f"Unknown truncation method {truncation}.")
......@@ -207,11 +210,18 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
eps: float = 1e-8,
weight_decay: float = 0.0,
use_kahan_summation: bool = True,
momentum_dtype: torch.dtype = torch.bfloat16,
variance_dtype: torch.dtype = torch.bfloat16,
compensation_buffer_dtype: torch.dtype = torch.bfloat16,
momentum_dtype: str = "bfloat16",
variance_dtype: str = "bfloat16",
compensation_buffer_dtype: str = "bfloat16",
):
"""
AnyPrecisionAdamW: a flexible precision AdamW optimizer
with optional Kahan summation for high precision weight updates.
Allows direct control over momentum, variance and auxiliary compensation buffer dtypes.
Optional Kahan summation is used to offset precision reduction for the weight updates.
This allows full training in BFloat16 (equal or better than FP32 results in many cases)
due to high precision weight updates.
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate (default: 1e-3)
......@@ -270,10 +280,11 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
eps = group["eps"]
use_kahan_summation = group["use_kahan_summation"]
momentum_dtype = group["momentum_dtype"]
variance_dtype = group["variance_dtype"]
compensation_buffer_dtype = group["compensation_buffer_dtype"]
momentum_dtype = PrecisionType.to_dtype(group["momentum_dtype"])
variance_dtype = PrecisionType.to_dtype(group["variance_dtype"])
compensation_buffer_dtype = PrecisionType.to_dtype(group["compensation_buffer_dtype"])
for p in group["params"]:
assert isinstance(p, torch.Tensor) # lint
if p.grad is None:
continue
......
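The Kahan summation mentioned in the AnyPrecisionAdamW docstring keeps a small compensation buffer alongside each parameter so that bfloat16 weight updates do not silently round tiny increments away. A standalone sketch of the trick (independent of the optimizer above; names are illustrative):

import torch

def kahan_add_(param: torch.Tensor, update: torch.Tensor, comp: torch.Tensor) -> None:
    """Apply `update` to `param` in low precision, carrying the rounding error in `comp`."""
    comp.add_(update)                  # fold the new update into the compensation buffer
    old_param = param.clone()
    param.add_(comp)                   # apply; small bits may be rounded away here
    comp.add_(old_param.sub_(param))   # comp += old - new, i.e. exactly what was lost

p = torch.ones(1, dtype=torch.bfloat16)
comp = torch.zeros_like(p)
for _ in range(1000):
    kahan_add_(p, torch.full_like(p, 1e-4), comp)  # 1e-4 alone is below bf16 resolution near 1.0
print(p)  # ~1.10 with compensation; a plain `p += 1e-4` loop would stay at 1.0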
......@@ -12,15 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOActor
from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
from .dp_actor import DataParallelPPOActor
__all__ = [
"ActorConfig",
"BasePPOActor",
"DataParallelPPOActor",
"FSDPConfig",
"ModelConfig",
"OptimConfig",
......
......@@ -15,6 +15,7 @@
Actor config
"""
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
......@@ -32,6 +33,12 @@ class ModelConfig:
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.model_path is not None and os.path.exists(self.model_path):
self.model_path = os.path.abspath(self.model_path)
if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
self.tokenizer_path = os.path.abspath(self.tokenizer_path)
@dataclass
class OptimConfig:
......
......@@ -20,9 +20,11 @@ from collections import defaultdict
from typing import Any, Dict, Optional
import torch
from einops import rearrange
from ray.experimental.tqdm_ray import tqdm
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.modeling_flash_attention_utils import index_first_axis, pad_input, unpad_input
from ...protocol import DataProto
from ...trainer import core_algos
......@@ -33,12 +35,6 @@ from .base import BasePPOActor
from .config import ActorConfig
try:
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
except ImportError:
pass
__all__ = ["DataParallelPPOActor"]
......
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOCritic
from .config import CriticConfig, ModelConfig
from .dp_critic import DataParallelPPOCritic
from .config import CriticConfig
__all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
__all__ = ["CriticConfig"]
......@@ -54,9 +54,7 @@ from ..utils.model_utils import print_gpu_memory_usage, print_model_size
from ..utils.tokenizer import get_processor, get_tokenizer
from ..utils.torch_dtypes import PrecisionType
from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
from .actor import DataParallelPPOActor
from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
from .critic import DataParallelPPOCritic
from .rollout import vLLMRollout
from .sharding_manager import FSDPVLLMShardingManager
from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
......@@ -264,6 +262,9 @@ class FSDPWorker(Worker):
else:
sync_module_states = False
param_init_fn = None
# TODO: pin the model to its assigned GPU before wrapping with FSDP
# (torch.cuda.set_device returns None, so query the current device afterwards)
torch.cuda.set_device(self.rank)
model = model.to(torch.cuda.current_device())
self.fsdp_module = FSDP(
model,
......@@ -365,6 +366,8 @@ class FSDPWorker(Worker):
print_gpu_memory_usage(f"After offload {role} optimizer during init")
if self._is_actor:
from .actor.dp_actor import DataParallelPPOActor # lazy import
self.actor = DataParallelPPOActor(
config=self.config.actor,
actor_module=self.fsdp_module,
......@@ -372,6 +375,8 @@ class FSDPWorker(Worker):
)
if self._is_critic:
from .critic.dp_critic import DataParallelPPOCritic # lazy import
self.critic = DataParallelPPOCritic(
config=self.config,
critic_module=self.fsdp_module,
......@@ -382,6 +387,8 @@ class FSDPWorker(Worker):
self._build_rollout()
if self._is_ref:
from .actor.dp_actor import DataParallelPPOActor # lazy import
self.ref_policy = DataParallelPPOActor(
config=self.config.ref,
actor_module=self.fsdp_module,
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from .config import RewardConfig
from .custom import CustomRewardManager
from .function import FunctionRewardManager
__all__ = ["CustomRewardManager", "RewardConfig"]
__all__ = ["FunctionRewardManager", "RewardConfig"]
......@@ -15,11 +15,28 @@
Reward config
"""
from dataclasses import dataclass
import os
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class RewardConfig:
reward_type: str = "function"
score_function: str = "math"
score_function: Optional[str] = None
score_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True
"""auto keys"""
score_function_name: Optional[str] = field(default=None, init=False)
def post_init(self):
if self.score_function is not None:
if ":" not in self.score_function:
self.score_function_name = "main"
else:
self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
if os.path.exists(self.score_function):
self.score_function = os.path.abspath(self.score_function)
else:
self.score_function = None
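With this layout, score_function names a Python file and, optionally, the function inside it, e.g. a hypothetical path/to/score.py:compute_score; post_init() splits off the suffix and falls back to loading `main` when no suffix is given. For example (paths and names here are only for illustration):

cfg = RewardConfig(score_function="./examples/math_score.py:compute_score")  # hypothetical path
cfg.post_init()
# cfg.score_function      -> absolute path of ./examples/math_score.py when the file exists (otherwise reset to None)
# cfg.score_function_name -> "compute_score"

cfg = RewardConfig(score_function="./examples/math_score.py")  # no ":name" suffix
cfg.post_init()
# cfg.score_function_name -> "main"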
......@@ -12,34 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import os
import sys
from collections import defaultdict
from typing import Callable, Dict, List, Tuple, TypedDict
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
import torch
from transformers import PreTrainedTokenizer
from ...protocol import DataProto
from ...utils.reward_score import math_compute_score, r1v_compute_score
from .config import RewardConfig
class RewardScore(TypedDict):
overall: float
format: float
accuracy: float
format: Optional[float]
accuracy: Optional[float]
ScoreFunction = Callable[[str, str], RewardScore]
@dataclass
class FunctionRewardManager:
config: RewardConfig
tokenizer: PreTrainedTokenizer
def __post_init__(self):
"""Load score function."""
if self.config.score_function is None:
raise ValueError("Score function is not provided.")
if not os.path.exists(self.config.score_function):
raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_score_fn"] = module
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Failed to load score function: {e}")
if not hasattr(module, self.config.score_function_name):
raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")
class CustomRewardManager:
def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
self.config = config
self.tokenizer = tokenizer
if config.score_function == "math":
self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
elif config.score_function == "r1v":
self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
else:
raise NotImplementedError(f"Unknown score function {config.score_function}.")
score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
self.score_fn = partial(score_fn, **self.config.score_function_kwargs)
def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
......@@ -56,7 +79,7 @@ class CustomRewardManager:
)
ground_truth = data_item.non_tensor_batch["ground_truth"]
score = self.compute_score(response_str, ground_truth)
score = self.score_fn(response_str, ground_truth)
reward_tensor[i, valid_response_length - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)
......
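Any file loaded this way just has to expose a callable matching ScoreFunction: it receives the decoded response and the ground truth and returns a RewardScore dict whose "overall" entry becomes the reward on the final response token, while the remaining keys are logged as reward metrics. A minimal hypothetical score file:

# my_score.py -- hypothetical file usable as worker.reward.score_function=./my_score.py
# (post_init resolves the path and loads `main` when no `:name` suffix is given).
import re
from typing import Dict


def main(response: str, ground_truth: str) -> Dict[str, float]:
    """Give reward 1.0 when the last \\boxed{...} answer matches the ground truth exactly."""
    matches = re.findall(r"\\boxed\{([^}]*)\}", response)
    answer = matches[-1].strip() if matches else ""
    accuracy = 1.0 if answer == ground_truth.strip() else 0.0
    return {"overall": accuracy, "format": 1.0 if matches else 0.0, "accuracy": accuracy}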