Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
# ignore lm head
layers = find_layers(model)
for name in ['lm_head']:
for name in ["lm_head"]:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize)
if checkpoint.endswith('.safetensors'):
if checkpoint.endswith(".safetensors"):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
......
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
import torch
import torch.nn as nn
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res
......@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
......@@ -68,7 +67,7 @@ class Quantizer(nn.Module):
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
......@@ -123,13 +122,12 @@ class Quantizer(nn.Module):
try:
import quant_cuda
except:
print('CUDA extension not installed.')
print("CUDA extension not installed.")
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
......@@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
dtype=torch.int))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer('bias', torch.zeros(outfeatures))
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
"qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
)
self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer("bias", torch.zeros(outfeatures))
self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
......@@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
None])
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
:, None
]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
......@@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
return y.reshape(outshape)
def make_quant(module, names, bits, groupsize, name=''):
def make_quant(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
name1 = name + "." + attr if name != "" else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
......@@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
@contextmanager
def low_resource_init():
"""This context manager disables weight initialization and sets the default float dtype to half.
"""
"""This context manager disables weight initialization and sets the default float dtype to half."""
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
old_uniform_ = torch.nn.init.uniform_
old_normal_ = torch.nn.init.normal_
......
......@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class TrainerCallback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
......@@ -40,7 +40,6 @@ class TrainerCallback(ABC):
class MakerCallback(ABC):
def on_loop_start(self) -> None:
pass
......
......@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
......@@ -42,13 +41,13 @@ class Timer:
self.duration += time() - self.start_time
def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
class ExperienceMakerPerformanceEvaluator(MakerCallback):
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
reward_model_num_params: int) -> None:
def __init__(
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
......@@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
self.make_experience_flop: int = 0
print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_make_experience_start(self) -> None:
......@@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
(self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
self.total_samples * self.world_size
)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Making Experience Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
class TrainerPerformanceEvaluator(TrainerCallback):
def __init__(self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
......@@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
self.learn_flop: int = 0
print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_episode_start(self, episodes: int) -> None:
......@@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
def on_fit_end(self) -> None:
if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation')
print_rank_0("No samples are collected, skip trainer performance evaluation")
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
......@@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Learning Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
import asyncio
import copy
import random
from threading import Lock
from typing import Any, List
from typing import List
import ray
import torch
from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
class DetachedReplayBuffer:
'''
"""
Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
......@@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
'''
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size
......@@ -34,23 +29,23 @@ class DetachedReplayBuffer:
@torch.no_grad()
def append(self, experience: Experience) -> None:
'''
"""
Expected to be called remotely.
'''
"""
items = split_experience_batch(experience)
self.extend(items)
@torch.no_grad()
def extend(self, items: List[BufferItem]) -> None:
'''
"""
Expected to be called remotely.
'''
"""
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
items = self.batch_collector[:self.sample_batch_size]
items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items)
self.items.put(experience, block=True)
self.batch_collector = self.batch_collector[self.sample_batch_size:]
self.batch_collector = self.batch_collector[self.sample_batch_size :]
def clear(self) -> None:
# self.items.close()
......
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List
import ray
import torch
......@@ -15,7 +15,7 @@ from .utils import is_rank_0
class DetachedTrainer(ABC):
'''
"""
Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init:
......@@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
def __init__(self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False) -> None:
"""
def __init__(
self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False,
) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
......@@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
dataloader = DataLoader(data,
batch_size=1,
shuffle=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
dataloader = DataLoader(
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
)
for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
......@@ -104,7 +104,7 @@ class DetachedTrainer(ABC):
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
......
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
......@@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
)
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={
"buffer_length": 1,
"buffer_append": 1,
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
@ray.remote(
concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
)
class DetachedPPOTrainer(DetachedTrainer):
'''
"""
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training
......@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""
def __init__(
self,
......@@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
(self.actor, self.actor_optim), (self.critic, self.critic_optim)
)
# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug)
super().__init__(
experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug,
)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
self._update_lora_weights = update_lora_weights
......@@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
config['requires_grad_only'] = True
config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
......@@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update))
fully_update=fully_update,
)
)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
......@@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update))
fully_update=fully_update,
)
)
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
......@@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)
......
import os
import time
import tracemalloc
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray
import torch
import torch.nn as nn
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
"""
Args:
detached_trainer_name_list: str list to get ray actor handles
strategy:
kl_coef: the coefficient of kl divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
'''
"""
def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs):
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs,
):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
......@@ -66,8 +62,9 @@ class ExperienceMakerHolder:
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
reward_model_numel)
evaluator = ExperienceMakerPerformanceEvaluator(
actor_numel, critic_numel, initial_model_numel, reward_model_numel
)
callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
......@@ -89,9 +86,9 @@ class ExperienceMakerHolder:
self._target_idx = 0
if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')
print(f"[maker{get_rank()}] Waiting for INIT")
def _get_ready(self):
while not self._fully_initialized():
......@@ -136,7 +133,7 @@ class ExperienceMakerHolder:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
experience.to_device('cpu')
experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
......@@ -155,7 +152,7 @@ class ExperienceMakerHolder:
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
......@@ -163,7 +160,7 @@ class ExperienceMakerHolder:
batch = next(it)
self._inference_step(batch)
else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
......@@ -171,22 +168,24 @@ class ExperienceMakerHolder:
self._on_loop_end()
@ray.method(concurrency_group="model_io")
def update_experience_maker(self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None):
'''
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
TODO: load_state_dict integrate with model-sharding strategy
'''
def update_experience_maker(
self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None,
):
"""
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
TODO: load_state_dict integrate with model-sharding strategy
"""
_watch_memory = self._debug
if chunk_start:
if self._debug:
......@@ -202,18 +201,22 @@ class ExperienceMakerHolder:
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict)
new_actor_state_dict, new_actor_lora_config_dict
)
self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase)
self.experience_maker.actor.model, state_dict_increase
)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict)
new_critic_state_dict, new_critic_lora_config_dict
)
self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase)
self.experience_maker.critic, state_dict_increase
)
# the lock must be released after both actor and critic being updated
if chunk_end:
......@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
return new_kwargs
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict
import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer
@dataclass
......@@ -17,7 +15,7 @@ class LoRAConfig:
class LoRAConstructor:
'''
"""
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:
......@@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()
'''
"""
def __init__(self):
self.lora_config_dict = None
......@@ -45,10 +43,10 @@ class LoRAConstructor:
self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
'''
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
'''
"""
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
"""
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)
......@@ -56,24 +54,25 @@ class LoRAConstructor:
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A':
if k.rpartition(".")[-1] == "lora_A":
lora_A = v
layer_prefix = k.rpartition('.')[0]
elif k.rpartition('.')[-1] == 'lora_B':
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
layer_prefix = k.rpartition(".")[0]
elif k.rpartition(".")[-1] == "lora_B":
assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
raise ValueError('unexpected key')
raise ValueError("unexpected key")
return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w
if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
......@@ -81,21 +80,21 @@ class LoRAConstructor:
return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
'''
"""
The final reconstruction step
'''
"""
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
'''
"""
if keep_non_lora, also return non_lora state_dict
'''
"""
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k:
if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
......@@ -106,17 +105,19 @@ class LoRAConstructor:
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
'''
"""
extract LoraLinear model.
return OrderedDict(): name -> LoRAConfig
'''
"""
lora_config_dict = OrderedDict()
for name, child in model.named_modules():
if isinstance(child, LoraLinear):
lora_config_dict[name] = LoRAConfig(r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out)
lora_config_dict[name] = LoRAConfig(
r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out,
)
return lora_config_dict
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict
import torch
import torch.distributed as dist
......@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool:
......@@ -26,13 +26,13 @@ def get_world_size() -> int:
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom':
elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt':
elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama':
elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
......@@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom':
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt':
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama':
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{model}"')
......@@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2':
if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom':
elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt':
elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama':
elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
......@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
if strategy == 'ddp':
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
elif model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif model == 'opt':
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama':
elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
......@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
os.environ["RANK"] = env_info["rank"]
os.environ["LOCAL_RANK"] = env_info["local_rank"]
os.environ["WORLD_SIZE"] = env_info["world_size"]
os.environ["MASTER_PORT"] = env_info["master_port"]
os.environ["MASTER_ADDR"] = env_info["master_addr"]
def get_model_numel(model: nn.Module) -> int:
......@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers
def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
keep state_dict intact
'''
def state_dict_to(
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
"""
keep state_dict intact
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
......
......@@ -3,8 +3,4 @@ from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = [
'SLTrainer', 'OnPolicyTrainer',
'RewardModelTrainer', 'SFTTrainer',
'PPOTrainer'
]
__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
......@@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
def __init__(self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []) -> None:
def __init__(
self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = [],
) -> None:
super().__init__()
self.strategy = strategy
self.data_buffer = data_buffer
......
......@@ -2,4 +2,4 @@ from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
......@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class Callback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
......
......@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
if y == 0:
return float('inf')
elif y == float('inf'):
return float('nan')
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
......@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
......@@ -52,7 +51,7 @@ class Timer:
self.start_time = None
def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
class PerformanceEvaluator(Callback):
......@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
......@@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
avg_make_experience_throughput = self.make_experience_num_samples * \
self.world_size / (avg_make_experience_duration + 1e-12)
avg_make_experience_throughput = (
self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
......@@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
f'Performance summary:\n'
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
f"Performance summary:\n"
+ f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
......@@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):
"""
def __init__(self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None) -> None:
def __init__(
self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None,
) -> None:
super().__init__()
self.path = os.path.join(path, 'checkpoint')
self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
base_path = os.path.join(self.path, f'episode_{episode}')
base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped
continue
model_path = os.path.join(base_path, f'{model}.pt')
model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
......@@ -71,5 +72,5 @@ class SaveCheckpoint(Callback):
continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
......@@ -8,7 +8,7 @@ from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from colossalai.utils import get_current_device
......@@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs
......@@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs
) -> None:
def __init__(
self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
......@@ -130,18 +126,16 @@ class PPOTrainer(OnPolicyTrainer):
num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
......@@ -149,24 +143,23 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim.zero_grad()
# value loss
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'reward': experience.reward.mean().item()}
return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
......@@ -178,11 +171,7 @@ class PPOTrainer(OnPolicyTrainer):
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(
self.dataloader,
desc=f'Train epoch [{update_step + 1}]',
disable=not is_rank_0()
)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
......
......@@ -62,18 +62,15 @@ class RewardModelTrainer(SLTrainer):
if is_rank_0():
log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader),
self.loss.item(), self.dist, self.acc]],
columns=['step', 'loss', 'dist', 'acc']
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
columns=["step", "loss", "dist", "acc"],
)
log.to_csv('log.csv', mode='a', header=False, index=False)
log.to_csv("log.csv", mode="a", header=False, index=False)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0()
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
......@@ -93,10 +90,7 @@ class RewardModelTrainer(SLTrainer):
step_bar.update()
step_bar.close()
def _before_fit(self,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader):
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
......@@ -104,7 +98,7 @@ class RewardModelTrainer(SLTrainer):
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment