Commit 2369eb2b authored by chenych

update

parent ac9d2b05
@@ -98,7 +98,7 @@ class RayResourcePool(ResourcePool):
        # print(f"pg_name_prefix = {pg_name_prefix}")
        pg_scheme = [
            [
-               {"CPU": self.max_collocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_collocate_count}
+               {"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count}
                for _ in range(process_count)
            ]
            for process_count in self._store
@@ -145,8 +145,8 @@ def extract_pg_from_exist(
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
    assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not"
-   assert rp1.max_collocate_count == rp2.max_collocate_count, (
-       "Both RayResourcePool must has the same max_collocate_count"
+   assert rp1.max_colocate_count == rp2.max_colocate_count, (
+       "Both RayResourcePool must has the same max_colocate_count"
    )
    assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node"
    assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool"
@@ -259,7 +259,7 @@ class RayWorkerGroup(WorkerGroup):
        world_size = resource_pool.world_size
        self._world_size = world_size
        # cia.add_kwarg("_world_size", world_size)
-       num_gpus = 1 / resource_pool.max_collocate_count
+       num_gpus = 1 / resource_pool.max_colocate_count
        rank = -1
        local_world_size = resource_pool.store[0]
@@ -300,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
            if rank == 0:
                register_center_actor = None
-               for _ in range(360):
+               for _ in range(120):
                    if f"{self.name_prefix}_register_center" not in list_named_actors():
                        time.sleep(1)
                    else:
...
@@ -47,6 +47,14 @@ class DataConfig:
    seed: int = 1
    max_pixels: int = 4194304
    min_pixels: int = 262144
+   filter_overlong_prompts: bool = True
+
+   def post_init(self):
+       if self.format_prompt is not None:
+           if os.path.exists(self.format_prompt):
+               self.format_prompt = os.path.abspath(self.format_prompt)
+           else:
+               self.format_prompt = None

@dataclass
@@ -86,6 +94,10 @@ class TrainerConfig:
        if self.save_checkpoint_path is None:
            self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

+       self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
+       if self.load_checkpoint_path is not None:
+           self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)

@dataclass
class PPOConfig:
@@ -97,6 +109,7 @@ class PPOConfig:
    def post_init(self):
        self.worker.rollout.prompt_length = self.data.max_prompt_length
        self.worker.rollout.response_length = self.data.max_response_length
+       self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code
        self.worker.actor.disable_kl = self.algorithm.disable_kl
        self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
        self.worker.actor.kl_penalty = self.algorithm.kl_penalty
...
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..utils.dataset import RLHFDataset, collate_fn
from .config import DataConfig
def create_dataloader(
    config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]
) -> Tuple[StatefulDataLoader, StatefulDataLoader]:
    train_dataset = RLHFDataset(
        data_path=config.train_files,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )
    # use sampler for better ckpt resume
    if config.shuffle:
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(config.seed)
        sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
    else:
        sampler = SequentialSampler(data_source=train_dataset)

    train_dataloader = StatefulDataLoader(
        dataset=train_dataset,
        batch_size=config.rollout_batch_size,
        sampler=sampler,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=True,
    )

    val_dataset = RLHFDataset(
        data_path=config.val_files,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )
    val_dataloader = StatefulDataLoader(
        dataset=val_dataset,
        batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=False,
    )

    assert len(train_dataloader) >= 1
    assert len(val_dataloader) >= 1
    print(f"Size of train dataloader: {len(train_dataloader)}")
    print(f"Size of val dataloader: {len(val_dataloader)}")
    return train_dataloader, val_dataloader
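For orientation, a minimal standalone sketch of how this new helper could be driven, mirroring how Runner.run wires it into RayPPOTrainer further below. The model name and parquet paths are placeholders, and the remaining DataConfig fields are assumed to keep their dataclass defaults:

# Hypothetical usage sketch (placeholder model and data paths, not from the commit)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
data_config = DataConfig(train_files="data/train.parquet", val_files="data/val.parquet")
data_config.post_init()  # resolves format_prompt, a no-op when it is None
train_dataloader, val_dataloader = create_dataloader(config=data_config, tokenizer=tokenizer, processor=None)
print(f"train batches: {len(train_dataloader)}, val batches: {len(val_dataloader)}")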
@@ -11,21 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
-"""
import json
import torch
import ray
from omegaconf import OmegaConf
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
-from ..workers.reward import CustomRewardManager
+from ..workers.reward import FunctionRewardManager
from .config import PPOConfig
+from .data_loader import create_dataloader
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
@@ -36,7 +33,6 @@ class Runner:
    def run(self, config: PPOConfig):
        # print config
-       config.deep_post_init()
        print(json.dumps(config.to_dict(), indent=2))

        # instantiate tokenizer
@@ -69,13 +65,19 @@ class Runner:
        }
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

-       reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
-       val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
+       reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
+       val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
+
+       train_dataloader, val_dataloader = create_dataloader(
+           config=config.data, tokenizer=tokenizer, processor=processor
+       )

        trainer = RayPPOTrainer(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
+           train_dataloader=train_dataloader,
+           val_dataloader=val_dataloader,
            role_worker_mapping=role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
@@ -96,17 +98,26 @@ def main():
        default_config = OmegaConf.merge(default_config, file_config)

    ppo_config = OmegaConf.merge(default_config, cli_args)
-   ppo_config = OmegaConf.to_object(ppo_config)
+   ppo_config: PPOConfig = OmegaConf.to_object(ppo_config)
+   ppo_config.deep_post_init()

    if not ray.is_initialized():
+       runtime_env = {
+           "env_vars": {
+               "TOKENIZERS_PARALLELISM": "true",
+               "NCCL_DEBUG": "WARN",
+               "VLLM_LOGGING_LEVEL": "INFO",
+               "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
+               "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
+           }
+       }
        # this is for local ray cluster
        if torch.version.hip is not None:
            ray.init(num_gpus=torch.cuda.device_count(),
                     ignore_reinit_error=True,
-                    runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+                    runtime_env=runtime_env)
        else:
-           ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+           ray.init(runtime_env=runtime_env)

    runner = Runner.remote()
    ray.get(runner.run.remote(ppo_config))
...
@@ -110,11 +110,11 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
    }

-def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
+def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]:
    total_num_tokens = sum(batch.meta_info["global_token_num"])
    time = timing_raw["step"]
    return {
        "perf/total_num_tokens": total_num_tokens,
        "perf/time_per_step": time,
-       "perf/throughput": total_num_tokens / (time * n_gpus),
+       "perf/throughput": total_num_tokens / (time * num_gpus),
    }
@@ -30,7 +30,6 @@ import ray
import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm
-from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -40,7 +39,6 @@ from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWo
from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
-from ..utils.dataset import RLHFDataset, collate_fn
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
@@ -102,24 +100,16 @@ class ResourcePoolManager:
        """Get the resource pool of the worker."""
        return self.resource_pool_dict[self.mapping[role]]

-   def get_n_gpus(self) -> int:
+   def get_num_gpus(self) -> int:
        """Get the number of gpus in this cluster."""
        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

    def _check_resource_available(self):
        """Check if the resource pool can be satisfied in this ray cluster."""
-       node_available_resources = ray.state.available_resources_per_node()
-       node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
-
-       # check total required gpus can be satisfied
-       total_available_gpus = sum(node_available_gpus.values())
-       total_required_gpus = sum(
-           [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
-       )
-       if total_available_gpus < total_required_gpus:
-           raise ValueError(
-               f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
-           )
+       gpus_available = ray.available_resources().get("GPU", 0)
+       gpus_required = self.get_num_gpus()
+       if gpus_available < gpus_required:
+           raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")

def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
@@ -128,11 +118,8 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penal
    response_mask = data.batch["response_mask"]

    # compute kl between ref_policy and current policy
-   if "ref_log_probs" in data.batch.keys():
-       kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
-       kld = kld * response_mask  # (batch_size, response_length)
-   else:
-       kld = torch.zeros_like(response_mask, dtype=torch.float32)
+   kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
+   kld = kld * response_mask  # (batch_size, response_length)

    data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
@@ -193,6 +180,8 @@ class RayPPOTrainer:
        config: PPOConfig,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
+       train_dataloader: StatefulDataLoader,
+       val_dataloader: StatefulDataLoader,
        role_worker_mapping: dict[Role, Type[Worker]],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
@@ -201,6 +190,8 @@ class RayPPOTrainer:
    ):
        self.tokenizer = tokenizer
        self.processor = processor
+       self.train_dataloader = train_dataloader
+       self.val_dataloader = val_dataloader
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn
@@ -262,78 +253,13 @@ class RayPPOTrainer:
        ):
            raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.")

-       self._create_dataloader()
-
-   def _create_dataloader(self) -> None:
-       self.train_dataset = RLHFDataset(
-           data_path=self.config.data.train_files,
-           tokenizer=self.tokenizer,
-           processor=self.processor,
-           prompt_key=self.config.data.prompt_key,
-           answer_key=self.config.data.answer_key,
-           image_key=self.config.data.image_key,
-           max_prompt_length=self.config.data.max_prompt_length,
-           truncation="right",
-           format_prompt=self.config.data.format_prompt,
-           min_pixels=self.config.data.min_pixels,
-           max_pixels=self.config.data.max_pixels,
-       )
-       # use sampler for better ckpt resume
-       if self.config.data.shuffle:
-           train_dataloader_generator = torch.Generator()
-           train_dataloader_generator.manual_seed(self.config.data.seed)
-           sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
-       else:
-           sampler = SequentialSampler(data_source=self.train_dataset)
-       self.train_dataloader = StatefulDataLoader(
-           dataset=self.train_dataset,
-           batch_size=self.config.data.rollout_batch_size,
-           sampler=sampler,
-           num_workers=8,
-           collate_fn=collate_fn,
-           pin_memory=False,
-           drop_last=True,
-       )
-       self.val_dataset = RLHFDataset(
-           data_path=self.config.data.val_files,
-           tokenizer=self.tokenizer,
-           processor=self.processor,
-           prompt_key=self.config.data.prompt_key,
-           answer_key=self.config.data.answer_key,
-           image_key=self.config.data.image_key,
-           max_prompt_length=self.config.data.max_prompt_length,
-           truncation="right",
-           format_prompt=self.config.data.format_prompt,
-           min_pixels=self.config.data.min_pixels,
-           max_pixels=self.config.data.max_pixels,
-       )
-       self.val_dataloader = StatefulDataLoader(
-           dataset=self.val_dataset,
-           batch_size=len(self.val_dataset)
-           if self.config.data.val_batch_size == -1
-           else self.config.data.val_batch_size,
-           shuffle=False,
-           num_workers=8,
-           collate_fn=collate_fn,
-           pin_memory=False,
-           drop_last=False,
-       )
-       assert len(self.train_dataloader) >= 1
-       assert len(self.val_dataloader) >= 1
-       print(f"Size of train dataloader: {len(self.train_dataloader)}")
-       print(f"Size of val dataloader: {len(self.val_dataloader)}")
-       if self.config.trainer.max_steps is not None:
-           training_steps = self.config.trainer.max_steps
-       else:
-           training_steps = len(self.train_dataloader) * self.config.trainer.total_episodes
-       self.training_steps = training_steps
-       self.config.worker.actor.optim.training_steps = training_steps
-       self.config.worker.critic.optim.training_steps = training_steps
+       if config.trainer.max_steps is not None:
+           self.training_steps = config.trainer.max_steps
+       else:
+           self.training_steps = len(train_dataloader) * config.trainer.total_episodes
+
+       config.worker.actor.optim.training_steps = self.training_steps
+       config.worker.critic.optim.training_steps = self.training_steps
        print(f"Total training steps: {self.training_steps}")

    def _maybe_log_val_generations(
@@ -366,10 +292,10 @@ class RayPPOTrainer:
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

-           if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
+           if "multi_modal_data" in test_batch.non_tensor_batch.keys():
                test_gen_batch = test_batch.pop(
                    batch_keys=["input_ids", "attention_mask", "position_ids"],
-                   non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
+                   non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                )
            else:
                test_gen_batch = test_batch.pop(
@@ -567,10 +493,10 @@ class RayPPOTrainer:
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
-               if "multi_modal_inputs" in batch.non_tensor_batch.keys():
+               if "multi_modal_data" in batch.non_tensor_batch.keys():
                    gen_batch = batch.pop(
                        batch_keys=["input_ids", "attention_mask", "position_ids"],
-                       non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
+                       non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    )
                else:
                    gen_batch = batch.pop(
@@ -604,6 +530,7 @@ class RayPPOTrainer:
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)
+                   batch.non_tensor_batch.pop("multi_modal_data", None)

                    # compute reward
                    with _timer("reward", timing_raw):
@@ -694,10 +621,10 @@ class RayPPOTrainer:
                        self._save_checkpoint()

                # collect metrics
-               n_gpus = self.resource_pool_manager.get_n_gpus()
+               num_gpus = self.resource_pool_manager.get_num_gpus()
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-               metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+               metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))
                self.logger.log(data=metrics, step=self.global_step)
...
@@ -13,13 +13,12 @@
# limitations under the License.

import os
-import warnings
from typing import Optional, Union

import torch
import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from .checkpoint_manager import BaseCheckpointManager
@@ -59,21 +58,18 @@ class FSDPCheckpointManager(BaseCheckpointManager):
        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
        model_state_dict = torch.load(model_path, weights_only=False)
-       optimizer_state_dict = torch.load(optim_path, weights_only=False)
+       optim_state_dict = torch.load(optim_path, weights_only=False)
        extra_state_dict = torch.load(extra_state_path, weights_only=False)
-       lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
-
-       state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
-       optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-       with warnings.catch_warnings():
-           warnings.simplefilter("ignore")
-           with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
-               self.model.load_state_dict(model_state_dict)
-               if self.optimizer is not None:
-                   self.optimizer.load_state_dict(optimizer_state_dict)
-
-       if self.lr_scheduler is not None:
-           self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
+       state_dict_options = StateDictOptions(cpu_offload=True)
+       set_state_dict(
+           model=self.model,
+           optimizers=self.optimizer,
+           model_state_dict=model_state_dict,
+           optim_state_dict=optim_state_dict,
+           options=state_dict_options,
+       )
+       self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"])

        # recover random state
        if "rng" in extra_state_dict:
@@ -84,38 +80,22 @@ class FSDPCheckpointManager(BaseCheckpointManager):
        dist.barrier()

        # every rank will save its own model and optim shard
-       state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
-       optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-       with warnings.catch_warnings():
-           warnings.simplefilter("ignore")
-           with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
-               model_state_dict = self.model.state_dict()
-               if self.optimizer is not None:
-                   optimizer_state_dict = self.optimizer.state_dict()
-               else:
-                   optimizer_state_dict = None
-
-       if self.lr_scheduler is not None:
-           lr_scheduler_state_dict = self.lr_scheduler.state_dict()
-       else:
-           lr_scheduler_state_dict = None
-
-       extra_state_dict = {
-           "lr_scheduler": lr_scheduler_state_dict,
-           "rng": self.get_rng_state(),
-       }
-       model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
-       optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
-       extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
-       print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
-       print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
-       print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
-       torch.save(model_state_dict, model_path)
-       if self.optimizer is not None:
-           torch.save(optimizer_state_dict, optim_path)
-       torch.save(extra_state_dict, extra_path)
+       state_dict_options = StateDictOptions(cpu_offload=True)
+       model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
+       extra_state_dict = {
+           "lr_scheduler": self.lr_scheduler.state_dict(),
+           "rng": self.get_rng_state(),
+       }
+       model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
+       optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
+       extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
+       print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
+       print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
+       print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
+       torch.save(model_state_dict, model_path)
+       torch.save(optim_state_dict, optim_path)
+       torch.save(extra_state_dict, extra_path)

        # wait for everyone to dump to local
        dist.barrier()
...
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from datasets import load_dataset
+from jinja2 import Template
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
@@ -90,9 +91,10 @@ class RLHFDataset(Dataset, ImageProcessMixin):
        image_key: str = "images",
        max_prompt_length: int = 1024,
        truncation: str = "error",
-       format_prompt: str = None,
-       max_pixels: int = None,
-       min_pixels: int = None,
+       format_prompt: Optional[str] = None,
+       max_pixels: Optional[int] = None,
+       min_pixels: Optional[int] = None,
+       filter_overlong_prompts: bool = True,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
@@ -101,9 +103,9 @@ class RLHFDataset(Dataset, ImageProcessMixin):
        self.image_key = image_key
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
-       self.format_prompt = format_prompt
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
+       self.filter_overlong_prompts = filter_overlong_prompts

        if "@" in data_path:
            data_path, data_split = data_path.split("@")
@@ -111,22 +113,29 @@ class RLHFDataset(Dataset, ImageProcessMixin):
            data_split = "train"

        if os.path.isdir(data_path):
+           # when we use dataset builder, we should always refer to the train split
            self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
        elif os.path.isfile(data_path):
            self.dataset = load_dataset("parquet", data_files=data_path, split="train")
-       else:  # remote dataset
+       else:
+           # load remote dataset from huggingface hub
            self.dataset = load_dataset(data_path, split=data_split)

-   def __len__(self):
-       return len(self.dataset)
-
-   def __getitem__(self, index):
-       row_dict: dict = self.dataset[index]
-       prompt_str: str = row_dict[self.prompt_key]
+       self.format_prompt = None
+       if format_prompt:
+           with open(format_prompt, encoding="utf-8") as f:
+               self.format_prompt = f.read()
+
+       if self.filter_overlong_prompts:
+           self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")
+
+   def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
+       prompt_str: str = example[self.prompt_key]
        if self.format_prompt:
-           prompt_str = prompt_str + " " + self.format_prompt.strip()
+           format_prompt = Template(self.format_prompt.strip())
+           prompt_str = format_prompt.render(content=prompt_str)

-       if self.image_key in row_dict:
+       if self.image_key in example:
            # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
            content_list = []
            for i, content in enumerate(prompt_str.split("<image>")):
@@ -136,28 +145,47 @@ class RLHFDataset(Dataset, ImageProcessMixin):
                if content:
                    content_list.append({"type": "text", "text": content})

-           messages = [{"role": "user", "content": content_list}]
+           return [{"role": "user", "content": content_list}]
+       else:
+           return [{"role": "user", "content": prompt_str}]
+
+   def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
+       messages = self._build_messages(example)
+       processing_class = self.processor if self.processor is not None else self.tokenizer
+       return (
+           len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
+       )
+
+   def __len__(self):
+       return len(self.dataset)
+
+   def __getitem__(self, index):
+       example: dict = self.dataset[index]
+       messages = self._build_messages(example)
+
+       if self.image_key in example:
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-           images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
+           images = [self.process_image(image) for image in example.pop(self.image_key)]
            model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
-           row_dict["multi_modal_data"] = {"image": images}
-           row_dict["multi_modal_inputs"] = dict(model_inputs)
+           example["multi_modal_data"] = {"image": images}
+           example["multi_modal_inputs"] = dict(model_inputs)
+       else:
+           prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+           model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
+           input_ids = model_inputs.pop("input_ids")[0]
+           attention_mask = model_inputs.pop("attention_mask")[0]

+       if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
            # qwen2vl mrope
            position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids,
-               image_grid_thw=model_inputs["image_grid_thw"],
+               image_grid_thw=model_inputs.get("image_grid_thw"),
                attention_mask=attention_mask,
            )  # (3, seq_length)
        else:
-           messages = [{"role": "user", "content": prompt_str}]
-           prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-           model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
-           input_ids = model_inputs.pop("input_ids")[0]
-           attention_mask = model_inputs.pop("attention_mask")[0]
            position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)

        input_ids, attention_mask, position_ids = VF.postprocess_data(
@@ -169,9 +197,18 @@ class RLHFDataset(Dataset, ImageProcessMixin):
            left_pad=True,
            truncation=self.truncation,
        )

-       row_dict["input_ids"] = input_ids
-       row_dict["attention_mask"] = attention_mask
-       row_dict["position_ids"] = position_ids
-       row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
-       row_dict["ground_truth"] = row_dict.pop(self.answer_key)
-       return row_dict
+       raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+       if len(raw_prompt_ids) > self.max_prompt_length:
+           if self.truncation == "left":
+               raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
+           elif self.truncation == "right":
+               raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
+           elif self.truncation == "error":
+               raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
+
+       example["input_ids"] = input_ids
+       example["attention_mask"] = attention_mask
+       example["position_ids"] = position_ids
+       example["raw_prompt_ids"] = raw_prompt_ids
+       example["ground_truth"] = example.pop(self.answer_key)
+       return example
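To make the new Jinja-based prompt formatting concrete: `format_prompt` now points at a template file whose text is read in `__init__` and rendered with the raw question bound to `content`. A small hedged sketch follows; the template text is a hypothetical example, not part of the commit:

# Hypothetical contents of a format prompt file, e.g. format_prompt/math.jinja:
#     {{ content }} Please reason step by step, and put your final answer within \boxed{}.
from jinja2 import Template

template_text = r"{{ content }} Please reason step by step, and put your final answer within \boxed{}."
rendered = Template(template_text.strip()).render(content="What is 2 + 3?")
print(rendered)
# -> What is 2 + 3? Please reason step by step, and put your final answer within \boxed{}.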
@@ -71,7 +71,7 @@ class TensorBoardLogger(Logger):
        os.makedirs(tensorboard_dir, exist_ok=True)
        print(f"Saving tensorboard log to {tensorboard_dir}.")
        self.writer = SummaryWriter(tensorboard_dir)
-       self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={})
+       self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={"placeholder": 0})

    def log(self, data: Dict[str, Any], step: int) -> None:
        for key, value in data.items():
...
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .math import math_compute_score
from .r1v import r1v_compute_score
__all__ = ["math_compute_score", "r1v_compute_score"]
@@ -15,40 +15,28 @@
import torch

-HALF_LIST = [16, "16", "fp16", "float16"]
-FLOAT_LIST = [32, "32", "fp32", "float32"]
+HALF_LIST = ["fp16", "float16"]
+FLOAT_LIST = ["fp32", "float32"]
BFLOAT_LIST = ["bf16", "bfloat16"]

class PrecisionType:
-   """Type of precision used.
-
-   >>> PrecisionType.HALF == 16
-   True
-   >>> PrecisionType.HALF in (16, "16")
-   True
-   """
-
-   HALF = "16"
-   FLOAT = "32"
-   FULL = "64"
-   BFLOAT = "bf16"
-   MIXED = "mixed"
+   """Type of precision used."""

    @staticmethod
-   def is_fp16(precision):
+   def is_fp16(precision: str) -> bool:
        return precision in HALF_LIST

    @staticmethod
-   def is_fp32(precision):
+   def is_fp32(precision: str) -> bool:
        return precision in FLOAT_LIST

    @staticmethod
-   def is_bf16(precision):
+   def is_bf16(precision: str) -> bool:
        return precision in BFLOAT_LIST

    @staticmethod
-   def to_dtype(precision) -> torch.dtype:
+   def to_dtype(precision: str) -> torch.dtype:
        if precision in HALF_LIST:
            return torch.float16
        elif precision in FLOAT_LIST:
@@ -56,7 +44,7 @@ class PrecisionType:
        elif precision in BFLOAT_LIST:
            return torch.bfloat16
        else:
-           raise RuntimeError(f"unexpected precision: {precision}")
+           raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def to_str(precision: torch.dtype) -> str:
@@ -67,4 +55,4 @@ class PrecisionType:
        elif precision == torch.bfloat16:
            return "bfloat16"
        else:
-           raise RuntimeError(f"unexpected precision: {precision}")
+           raise RuntimeError(f"Unexpected precision: {precision}")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright Meta Platforms, Inc. and affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +23,8 @@ import torch.distributed
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

+from .torch_dtypes import PrecisionType

try:
    from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
@@ -177,7 +180,7 @@ def postprocess_data(
        attention_mask = attention_mask[..., :max_length]
        position_ids = position_ids[..., :max_length]
    elif truncation == "error":
-       raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
+       raise RuntimeError(f"Input sequence length {seq_length} is longer than max length {max_length}.")
    else:
        raise NotImplementedError(f"Unknown truncation method {truncation}.")
@@ -207,35 +210,42 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        use_kahan_summation: bool = True,
-       momentum_dtype: torch.dtype = torch.bfloat16,
-       variance_dtype: torch.dtype = torch.bfloat16,
-       compensation_buffer_dtype: torch.dtype = torch.bfloat16,
+       momentum_dtype: str = "bfloat16",
+       variance_dtype: str = "bfloat16",
+       compensation_buffer_dtype: str = "bfloat16",
    ):
        """
+       AnyPrecisionAdamW: a flexible precision AdamW optimizer
+       with optional Kahan summation for high precision weight updates.
+       Allows direct control over momentum, variance and auxiliary compensation buffer dtypes.
+       Optional Kahan summation is used to offset precision reduction for the weight updates.
+       This allows full training in BFloat16 (equal or better than FP32 results in many cases)
+       due to high precision weight updates.
+
        Args:
            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
            weight_decay (float, optional): weight decay coefficient (default: 1e-2)

            # Any Precision specific
            use_kahan_summation = creates auxiliary buffer to ensure high precision
                model param updates (default: False)
            momentum_dtype = dtype for momentum (default: bfloat16)
            variance_dtype = dtype for uncentered variance (default: bfloat16)
            compensation_buffer_dtype = dtype for Kahan summation buffer (default: bfloat16)

            # Usage
            This optimizer implements optimizer states, and Kahan summation
            for high precision updates, all in user controlled dtypes.
            Defaults are variance in BF16, Momentum in FP32.
            This can be run in FSDP mixed precision, amp, or full precision,
            depending on what training pipeline you wish to work with.

            Setting to use_kahan_summation = False, and changing momentum and
            variance dtypes to FP32, reverts this to a standard AdamW optimizer.
        """
        defaults = {
@@ -270,10 +280,11 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

-           momentum_dtype = group["momentum_dtype"]
-           variance_dtype = group["variance_dtype"]
-           compensation_buffer_dtype = group["compensation_buffer_dtype"]
+           momentum_dtype = PrecisionType.to_dtype(group["momentum_dtype"])
+           variance_dtype = PrecisionType.to_dtype(group["variance_dtype"])
+           compensation_buffer_dtype = PrecisionType.to_dtype(group["compensation_buffer_dtype"])

            for p in group["params"]:
+               assert isinstance(p, torch.Tensor)  # lint
                if p.grad is None:
                    continue
...
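A hedged usage sketch of the updated optimizer signature: the three buffer dtypes are now passed as strings and resolved through PrecisionType.to_dtype inside step(). The toy model below is illustrative only, not from the commit:

import torch
from torch import nn

model = nn.Linear(8, 8)
optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=1e-3,
    momentum_dtype="bfloat16",              # previously torch.bfloat16
    variance_dtype="bfloat16",
    compensation_buffer_dtype="bfloat16",
)
loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()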
@@ -12,15 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .base import BasePPOActor
from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
-from .dp_actor import DataParallelPPOActor

__all__ = [
    "ActorConfig",
-   "BasePPOActor",
-   "DataParallelPPOActor",
    "FSDPConfig",
    "ModelConfig",
    "OptimConfig",
...
@@ -15,6 +15,7 @@
Actor config
"""

+import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

@@ -32,6 +33,12 @@ class ModelConfig:
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

+       if self.model_path is not None and os.path.exists(self.model_path):
+           self.model_path = os.path.abspath(self.model_path)
+
+       if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
+           self.tokenizer_path = os.path.abspath(self.tokenizer_path)

@dataclass
class OptimConfig:
...
@@ -20,9 +20,11 @@ from collections import defaultdict
from typing import Any, Dict, Optional

import torch
+from einops import rearrange
from ray.experimental.tqdm_ray import tqdm
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from transformers.modeling_flash_attention_utils import index_first_axis, pad_input, unpad_input

from ...protocol import DataProto
from ...trainer import core_algos
@@ -33,12 +35,6 @@ from .base import BasePPOActor
from .config import ActorConfig

-try:
-   from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
-except ImportError:
-   pass

__all__ = ["DataParallelPPOActor"]
...
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .base import BasePPOCritic
-from .config import CriticConfig, ModelConfig
-from .dp_critic import DataParallelPPOCritic
-
-__all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
+from .config import CriticConfig
+
+__all__ = ["CriticConfig"]
@@ -54,9 +54,7 @@ from ..utils.model_utils import print_gpu_memory_usage, print_model_size
from ..utils.tokenizer import get_processor, get_tokenizer
from ..utils.torch_dtypes import PrecisionType
from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
-from .actor import DataParallelPPOActor
from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
-from .critic import DataParallelPPOCritic
from .rollout import vLLMRollout
from .sharding_manager import FSDPVLLMShardingManager
from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
@@ -264,6 +262,9 @@ class FSDPWorker(Worker):
        else:
            sync_module_states = False
            param_init_fn = None

+       ## TODO: pin the model to its assigned GPU
+       rank = torch.cuda.set_device(self.rank)
+       model = model.to(rank)
        self.fsdp_module = FSDP(
            model,
@@ -365,6 +366,8 @@ class FSDPWorker(Worker):
            print_gpu_memory_usage(f"After offload {role} optimizer during init")

        if self._is_actor:
+           from .actor.dp_actor import DataParallelPPOActor  # lazy import
+
            self.actor = DataParallelPPOActor(
                config=self.config.actor,
                actor_module=self.fsdp_module,
@@ -372,6 +375,8 @@ class FSDPWorker(Worker):
            )

        if self._is_critic:
+           from .critic.dp_critic import DataParallelPPOCritic  # lazy import
+
            self.critic = DataParallelPPOCritic(
                config=self.config,
                critic_module=self.fsdp_module,
@@ -382,6 +387,8 @@ class FSDPWorker(Worker):
            self._build_rollout()

        if self._is_ref:
+           from .actor.dp_actor import DataParallelPPOActor  # lazy import
+
            self.ref_policy = DataParallelPPOActor(
                config=self.config.ref,
                actor_module=self.fsdp_module,
...
@@ -13,7 +13,7 @@
# limitations under the License.

from .config import RewardConfig
-from .custom import CustomRewardManager
+from .function import FunctionRewardManager

-__all__ = ["CustomRewardManager", "RewardConfig"]
+__all__ = ["FunctionRewardManager", "RewardConfig"]
@@ -15,11 +15,28 @@
Reward config
"""

-from dataclasses import dataclass
+import os
+from dataclasses import dataclass, field
+from typing import Optional

@dataclass
class RewardConfig:
    reward_type: str = "function"
-   score_function: str = "math"
+   score_function: Optional[str] = None
+   score_function_kwargs: dict = field(default_factory=dict)
    skip_special_tokens: bool = True
+   """auto keys"""
+   score_function_name: Optional[str] = field(default=None, init=False)
+
+   def post_init(self):
+       if self.score_function is not None:
+           if ":" not in self.score_function:
+               self.score_function_name = "main"
+           else:
+               self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
+
+           if os.path.exists(self.score_function):
+               self.score_function = os.path.abspath(self.score_function)
+       else:
+           self.score_function = None
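A hedged illustration of how the new `score_function` string is parsed by post_init (the file path below is a placeholder, not from the commit):

cfg = RewardConfig(score_function="examples/score_function/math.py:compute_score")
cfg.post_init()
# cfg.score_function      -> absolute path of examples/score_function/math.py (when the file exists)
# cfg.score_function_name -> "compute_score"

cfg = RewardConfig(score_function="examples/score_function/math.py")
cfg.post_init()
# cfg.score_function_name -> "main"  (the default entry point when no ":name" suffix is given)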
@@ -12,34 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import importlib.util
+import os
+import sys
from collections import defaultdict
-from typing import Callable, Dict, List, Tuple, TypedDict
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, List, Optional, Tuple, TypedDict

import torch
from transformers import PreTrainedTokenizer

from ...protocol import DataProto
-from ...utils.reward_score import math_compute_score, r1v_compute_score
from .config import RewardConfig

class RewardScore(TypedDict):
    overall: float
-   format: float
-   accuracy: float
+   format: Optional[float]
+   accuracy: Optional[float]

+ScoreFunction = Callable[[str, str], RewardScore]

-class CustomRewardManager:
-   def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
-       self.config = config
-       self.tokenizer = tokenizer
-       if config.score_function == "math":
-           self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
-       elif config.score_function == "r1v":
-           self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
-       else:
-           raise NotImplementedError(f"Unknown score function {config.score_function}.")
+@dataclass
+class FunctionRewardManager:
+   config: RewardConfig
+   tokenizer: PreTrainedTokenizer
+
+   def __post_init__(self):
+       """Load score function."""
+       if self.config.score_function is None:
+           raise ValueError("Score function is not provided.")
+
+       if not os.path.exists(self.config.score_function):
+           raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
+
+       spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
+       module = importlib.util.module_from_spec(spec)
+       try:
+           sys.modules["custom_score_fn"] = module
+           spec.loader.exec_module(module)
+       except Exception as e:
+           raise RuntimeError(f"Failed to load score function: {e}")
+
+       if not hasattr(module, self.config.score_function_name):
+           raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")
+
+       score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
+       print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
+       self.score_fn = partial(score_fn, **self.config.score_function_kwargs)

    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
@@ -56,7 +79,7 @@ class CustomRewardManager:
            )
            ground_truth = data_item.non_tensor_batch["ground_truth"]

-           score = self.compute_score(response_str, ground_truth)
+           score = self.score_fn(response_str, ground_truth)
            reward_tensor[i, valid_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)
...
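To show what a pluggable score function file might look like, here is a hypothetical module compatible with the `ScoreFunction = Callable[[str, str], RewardScore]` contract loaded above; the file name, regex, and weighting are illustrative only:

# my_score.py (hypothetical), referenced via worker.reward.score_function=my_score.py or my_score.py:main
import re
from typing import Optional, TypedDict


class RewardScore(TypedDict):
    overall: float
    format: Optional[float]
    accuracy: Optional[float]


def main(response: str, ground_truth: str) -> RewardScore:
    """Default entry point used when no ':name' suffix is configured."""
    has_boxed_answer = re.search(r"\\boxed\{.+\}", response) is not None
    accuracy = 1.0 if ground_truth.strip() and ground_truth.strip() in response else 0.0
    return {
        "overall": 0.9 * accuracy + 0.1 * float(has_boxed_answer),
        "format": float(has_boxed_answer),
        "accuracy": accuracy,
    }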