Unverified commit 079bf3cb, authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int): ...@@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
# ignore lm head # ignore lm head
layers = find_layers(model) layers = find_layers(model)
for name in ['lm_head']: for name in ["lm_head"]:
if name in layers: if name in layers:
del layers[name] del layers[name]
make_quant(model, layers, wbits, groupsize) make_quant(model, layers, wbits, groupsize)
if checkpoint.endswith('.safetensors'): if checkpoint.endswith(".safetensors"):
from safetensors.torch import load_file as safe_load from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint)) model.load_state_dict(safe_load(checkpoint))
else: else:
model.load_state_dict(torch.load(checkpoint)) model.load_state_dict(torch.load(checkpoint))
......
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py # copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
import torch
import torch.nn as nn import torch.nn as nn
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers: if type(module) in layers:
return {name: module} return {name: module}
res = {} res = {}
for name1, child in module.named_children(): for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res return res
...@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq): ...@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
class Quantizer(nn.Module): class Quantizer(nn.Module):
def __init__(self, shape=1): def __init__(self, shape=1):
super(Quantizer, self).__init__() super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0)) self.register_buffer("maxq", torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape)) self.register_buffer("scale", torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape)) self.register_buffer("zero", torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8): def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
self.maxq = torch.tensor(2**bits - 1) self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel self.perchannel = perchannel
self.sym = sym self.sym = sym
...@@ -68,7 +67,7 @@ class Quantizer(nn.Module): ...@@ -68,7 +67,7 @@ class Quantizer(nn.Module):
self.zero = torch.round(-xmin / self.scale) self.zero = torch.round(-xmin / self.scale)
if self.mse: if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev) best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)): for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid p = 1 - i / self.grid
xmin1 = p * xmin xmin1 = p * xmin
...@@ -123,13 +122,12 @@ class Quantizer(nn.Module): ...@@ -123,13 +122,12 @@ class Quantizer(nn.Module):
try: try:
import quant_cuda import quant_cuda
except: except:
print('CUDA extension not installed.') print("CUDA extension not installed.")
# Assumes layer is perfectly divisible into 256 * 256 blocks # Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module): class QuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures): def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__() super().__init__()
if bits not in [2, 3, 4, 8]: if bits not in [2, 3, 4, 8]:
...@@ -142,11 +140,11 @@ class QuantLinear(nn.Module): ...@@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
groupsize = groupsize if groupsize != -1 else infeatures groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize self.groupsize = groupsize
self.register_buffer( self.register_buffer(
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
dtype=torch.int)) )
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer('bias', torch.zeros(outfeatures)) self.register_buffer("bias", torch.zeros(outfeatures))
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False self._initialized_quant_state = False
def pack(self, linear, scales, zeros): def pack(self, linear, scales, zeros):
...@@ -161,8 +159,10 @@ class QuantLinear(nn.Module): ...@@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
for idx in range(self.infeatures): for idx in range(self.infeatures):
g_idx = idx // self.groupsize g_idx = idx // self.groupsize
intweight.append( intweight.append(
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
None]) :, None
]
)
intweight = torch.cat(intweight, dim=1) intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous() intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32) intweight = intweight.numpy().astype(np.uint32)
...@@ -271,13 +271,13 @@ class QuantLinear(nn.Module): ...@@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
return y.reshape(outshape) return y.reshape(outshape)
def make_quant(module, names, bits, groupsize, name=''): def make_quant(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear): if isinstance(module, QuantLinear):
return return
for attr in dir(module): for attr in dir(module):
tmp = getattr(module, attr) tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr name1 = name + "." + attr if name != "" else attr
if name1 in names: if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)) setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children(): for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
...@@ -9,8 +9,7 @@ def _noop(*args, **kwargs): ...@@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
@contextmanager @contextmanager
def low_resource_init(): def low_resource_init():
"""This context manager disables weight initialization and sets the default float dtype to half. """This context manager disables weight initialization and sets the default float dtype to half."""
"""
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_ old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
old_uniform_ = torch.nn.init.uniform_ old_uniform_ = torch.nn.init.uniform_
old_normal_ = torch.nn.init.normal_ old_normal_ = torch.nn.init.normal_
......
...@@ -5,7 +5,7 @@ from coati.experience_maker import Experience ...@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class TrainerCallback(ABC): class TrainerCallback(ABC):
""" """
Base callback class. It defines the interface for callbacks. Base callback class. It defines the interface for callbacks.
""" """
def on_fit_start(self) -> None: def on_fit_start(self) -> None:
...@@ -40,7 +40,6 @@ class TrainerCallback(ABC): ...@@ -40,7 +40,6 @@ class TrainerCallback(ABC):
class MakerCallback(ABC): class MakerCallback(ABC):
def on_loop_start(self) -> None: def on_loop_start(self) -> None:
pass pass
......
...@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: ...@@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer: class Timer:
def __init__(self) -> None: def __init__(self) -> None:
self.start_time: Optional[float] = None self.start_time: Optional[float] = None
self.duration: float = 0. self.duration: float = 0.0
def start(self) -> None: def start(self) -> None:
self.start_time = time() self.start_time = time()
...@@ -42,13 +41,13 @@ class Timer: ...@@ -42,13 +41,13 @@ class Timer:
self.duration += time() - self.start_time self.duration += time() - self.start_time
def reset(self) -> None: def reset(self) -> None:
self.duration = 0. self.duration = 0.0
class ExperienceMakerPerformanceEvaluator(MakerCallback): class ExperienceMakerPerformanceEvaluator(MakerCallback):
def __init__(
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
reward_model_num_params: int) -> None: ) -> None:
super().__init__() super().__init__()
self.world_size = get_world_size() self.world_size = get_world_size()
self.actor_num_params = actor_num_params self.actor_num_params = actor_num_params
...@@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback): ...@@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
self.make_experience_flop: int = 0 self.make_experience_flop: int = 0
print_rank_0( print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}' f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
) )
def on_make_experience_start(self) -> None: def on_make_experience_start(self) -> None:
...@@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback): ...@@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12) avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size) avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \ avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
(self.total_samples * self.world_size) self.total_samples * self.world_size
)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size) avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0( print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' "Making Experience Performance Summary:\n"
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n' + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n' + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n' + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
) )
class TrainerPerformanceEvaluator(TrainerCallback): class TrainerPerformanceEvaluator(TrainerCallback):
def __init__(
def __init__(self, self,
actor_num_params: int, actor_num_params: int,
critic_num_params: int, critic_num_params: int,
enable_grad_checkpoint: bool = False, enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None: ignore_first_episodes: int = 1,
) -> None:
super().__init__() super().__init__()
self.world_size = get_world_size() self.world_size = get_world_size()
self.actor_num_params = actor_num_params self.actor_num_params = actor_num_params
...@@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback): ...@@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
self.learn_flop: int = 0 self.learn_flop: int = 0
print_rank_0( print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}' f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
) )
def on_episode_start(self, episodes: int) -> None: def on_episode_start(self, episodes: int) -> None:
...@@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback): ...@@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
def on_fit_end(self) -> None: def on_fit_end(self) -> None:
if self.total_samples == 0: if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation') print_rank_0("No samples are collected, skip trainer performance evaluation")
return return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size) avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size) avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
...@@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback): ...@@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size) avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0( print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' "Learning Performance Summary:\n"
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n' + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n' + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
) )
import asyncio from typing import List
import copy
import random
from threading import Lock
from typing import Any, List
import ray
import torch import torch
from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience from coati.experience_maker.base import Experience
# from torch.multiprocessing import Queue # from torch.multiprocessing import Queue
from ray.util.queue import Queue from ray.util.queue import Queue
class DetachedReplayBuffer: class DetachedReplayBuffer:
''' """
Detached replay buffer. Share Experience across workers on the same node. Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance. Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely. It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
...@@ -24,7 +19,7 @@ class DetachedReplayBuffer: ...@@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group tp_world_size: Number of workers in the same tp group
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
''' """
def __init__(self, sample_batch_size: int, limit: int = 0) -> None: def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size self.sample_batch_size = sample_batch_size
...@@ -34,23 +29,23 @@ class DetachedReplayBuffer: ...@@ -34,23 +29,23 @@ class DetachedReplayBuffer:
@torch.no_grad() @torch.no_grad()
def append(self, experience: Experience) -> None: def append(self, experience: Experience) -> None:
''' """
Expected to be called remotely. Expected to be called remotely.
''' """
items = split_experience_batch(experience) items = split_experience_batch(experience)
self.extend(items) self.extend(items)
@torch.no_grad() @torch.no_grad()
def extend(self, items: List[BufferItem]) -> None: def extend(self, items: List[BufferItem]) -> None:
''' """
Expected to be called remotely. Expected to be called remotely.
''' """
self.batch_collector.extend(items) self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size: while len(self.batch_collector) >= self.sample_batch_size:
items = self.batch_collector[:self.sample_batch_size] items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items) experience = make_experience_batch(items)
self.items.put(experience, block=True) self.items.put(experience, block=True)
self.batch_collector = self.batch_collector[self.sample_batch_size:] self.batch_collector = self.batch_collector[self.sample_batch_size :]
def clear(self) -> None: def clear(self) -> None:
# self.items.close() # self.items.close()
......
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union from typing import Any, Dict, List
import ray import ray
import torch import torch
...@@ -15,7 +15,7 @@ from .utils import is_rank_0 ...@@ -15,7 +15,7 @@ from .utils import is_rank_0
class DetachedTrainer(ABC): class DetachedTrainer(ABC):
''' """
Base class for detached rlhf trainers. Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer. 'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init: Please set name attribute during init:
...@@ -28,15 +28,17 @@ class DetachedTrainer(ABC): ...@@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating generate_kwargs (dict, optional): the kwargs to use while model generating
''' """
def __init__(self, def __init__(
experience_maker_holder_name_list: List[str], self,
train_batch_size: int = 8, experience_maker_holder_name_list: List[str],
buffer_limit: int = 0, train_batch_size: int = 8,
dataloader_pin_memory: bool = True, buffer_limit: int = 0,
callbacks: List[TrainerCallback] = [], dataloader_pin_memory: bool = True,
debug: bool = False) -> None: callbacks: List[TrainerCallback] = [],
debug: bool = False,
) -> None:
super().__init__() super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit) self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory self.dataloader_pin_memory = dataloader_pin_memory
...@@ -67,18 +69,16 @@ class DetachedTrainer(ABC): ...@@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
def _learn(self, update_steps: int, train_epochs: int) -> None: def _learn(self, update_steps: int, train_epochs: int) -> None:
data = [] data = []
# warmup # warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0()) pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0) self._on_epoch_start(0)
self._learn_epoch(pbar, data) self._learn_epoch(pbar, data)
self._on_epoch_end(0) self._on_epoch_end(0)
# item is already a batch # item is already a batch
dataloader = DataLoader(data, dataloader = DataLoader(
batch_size=1, data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
shuffle=True, )
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
for epoch in range(1, train_epochs): for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0()) pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch) self._on_epoch_start(epoch)
self._learn_epoch(pbar, data) self._learn_epoch(pbar, data)
self._on_epoch_end(epoch) self._on_epoch_end(epoch)
...@@ -104,7 +104,7 @@ class DetachedTrainer(ABC): ...@@ -104,7 +104,7 @@ class DetachedTrainer(ABC):
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None: def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start() self._on_fit_start()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()): for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i) self._on_episode_start(i)
self._learn(update_steps, train_epochs) self._learn(update_steps, train_epochs)
self._on_update_start() self._on_update_start()
......
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Tuple
import ray import ray
import torch import torch
from coati.experience_maker import Experience, NaiveExperienceMaker from coati.experience_maker import Experience
from coati.models.base import Actor, Critic from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
...@@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor from .lora_constructor import LoRAConstructor
from .utils import ( from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
)
@ray.remote(concurrency_groups={ @ray.remote(
"buffer_length": 1, concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
"buffer_append": 1, )
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
class DetachedPPOTrainer(DetachedTrainer): class DetachedPPOTrainer(DetachedTrainer):
''' """
Detached Trainer for PPO algorithm Detached Trainer for PPO algorithm
Args: Args:
strategy (Strategy): the strategy to use for training strategy (Strategy): the strategy to use for training
...@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer): ...@@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating generate_kwargs (dict, optional): the kwargs to use while model generating
''' """
def __init__( def __init__(
self, self,
...@@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer): ...@@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7) self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7) self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \ (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) (self.actor, self.actor_optim), (self.critic, self.critic_optim)
)
# configure trainer # configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip) self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip) self.critic_loss_fn = ValueLoss(value_clip)
super().__init__(experience_maker_holder_name_list, super().__init__(
train_batch_size=train_batch_size, experience_maker_holder_name_list,
buffer_limit=buffer_limit, train_batch_size=train_batch_size,
dataloader_pin_memory=dataloader_pin_memory, buffer_limit=buffer_limit,
callbacks=callbacks, dataloader_pin_memory=dataloader_pin_memory,
debug=debug) callbacks=callbacks,
debug=debug,
)
if self._debug: if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}') print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
self._update_lora_weights = update_lora_weights self._update_lora_weights = update_lora_weights
...@@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer): ...@@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
def _update_remote_makers(self, fully_update: bool = False, **config): def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties # TODO: balance duties
if not fully_update: if not fully_update:
config['requires_grad_only'] = True config["requires_grad_only"] = True
self.update_target_holder_list() self.update_target_holder_list()
# mark start, ensure order # mark start, ensure order
tasks = [] tasks = []
...@@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer): ...@@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote( target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard, new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor), new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update)) fully_update=fully_update,
)
)
# sending loop # sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config): for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list: for target_holder in self.target_holder_list:
...@@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer): ...@@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote( target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard, new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic), new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update)) fully_update=fully_update,
)
)
ray.get(tasks) ray.get(tasks)
# mark end # mark end
for target_holder in self.target_holder_list: for target_holder in self.target_holder_list:
...@@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer): ...@@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):
num_actions = experience.action_mask.size(1) num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs, actor_loss = self.actor_loss_fn(
experience.action_log_probs, action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
experience.advantages, )
action_mask=experience.action_mask)
self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim) self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad() self.actor_optim.zero_grad()
values = self.critic(experience.sequences, values = self.critic(
action_mask=experience.action_mask, experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
attention_mask=experience.attention_mask) )
critic_loss = self.critic_loss_fn(values, critic_loss = self.critic_loss_fn(
experience.values, values, experience.values, experience.reward, action_mask=experience.action_mask
experience.reward, )
action_mask=experience.action_mask)
self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim) self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad() self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0) self.strategy.save_model(self.actor, path, only_rank0)
......
import os import os
import time import time
import tracemalloc import tracemalloc
from copy import deepcopy
from threading import Lock from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray import ray
import torch import torch
import torch.nn as nn from coati.experience_buffer.utils import split_experience_batch
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor from torch import Tensor
from tqdm import tqdm from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder: class ExperienceMakerHolder:
''' """
Args: Args:
detached_trainer_name_list: str list to get ray actor handles detached_trainer_name_list: str list to get ray actor handles
strategy: strategy:
kl_coef: the coefficient of kl divergence loss kl_coef: the coefficient of kl divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models. sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
''' """
def __init__( def __init__(
self, self,
detached_trainer_name_list: List[str], detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy], strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model) # a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None, env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False, sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True, buffer_cpu_offload: bool = True,
kl_coef: float = 0.1, kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [], callbacks: List[MakerCallback] = [],
eval_performance: bool = False, eval_performance: bool = False,
debug: bool = False, debug: bool = False,
update_lora_weights: bool = False, update_lora_weights: bool = False,
**generate_kwargs): **generate_kwargs,
):
# set environment variables # set environment variables
if env_info: if env_info:
set_dist_env(env_info=env_info) set_dist_env(env_info=env_info)
...@@ -66,8 +62,9 @@ class ExperienceMakerHolder: ...@@ -66,8 +62,9 @@ class ExperienceMakerHolder:
critic_numel = get_model_numel(critic) critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model) initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model) reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, evaluator = ExperienceMakerPerformanceEvaluator(
reward_model_numel) actor_numel, critic_numel, initial_model_numel, reward_model_numel
)
callbacks = callbacks + [evaluator] callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model) actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
...@@ -89,9 +86,9 @@ class ExperienceMakerHolder: ...@@ -89,9 +86,9 @@ class ExperienceMakerHolder:
self._target_idx = 0 self._target_idx = 0
if self._debug: if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}') print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized: if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT') print(f"[maker{get_rank()}] Waiting for INIT")
def _get_ready(self): def _get_ready(self):
while not self._fully_initialized(): while not self._fully_initialized():
...@@ -136,7 +133,7 @@ class ExperienceMakerHolder: ...@@ -136,7 +133,7 @@ class ExperienceMakerHolder:
self._on_make_experience_end(experience) self._on_make_experience_end(experience)
self._on_send_start() self._on_send_start()
if self.buffer_cpu_offload: if self.buffer_cpu_offload:
experience.to_device('cpu') experience.to_device("cpu")
self._send_items(experience) self._send_items(experience)
self._on_send_end() self._on_send_end()
self._on_batch_end() self._on_batch_end()
...@@ -155,7 +152,7 @@ class ExperienceMakerHolder: ...@@ -155,7 +152,7 @@ class ExperienceMakerHolder:
if num_steps > 0: if num_steps > 0:
# ignore num epochs # ignore num epochs
it = iter(dataloader) it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()): for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try: try:
batch = next(it) batch = next(it)
except StopIteration: except StopIteration:
...@@ -163,7 +160,7 @@ class ExperienceMakerHolder: ...@@ -163,7 +160,7 @@ class ExperienceMakerHolder:
batch = next(it) batch = next(it)
self._inference_step(batch) self._inference_step(batch)
else: else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar: with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs): for _ in range(num_epochs):
for batch in dataloader: for batch in dataloader:
self._inference_step(batch) self._inference_step(batch)
...@@ -171,22 +168,24 @@ class ExperienceMakerHolder: ...@@ -171,22 +168,24 @@ class ExperienceMakerHolder:
self._on_loop_end() self._on_loop_end()
@ray.method(concurrency_group="model_io") @ray.method(concurrency_group="model_io")
def update_experience_maker(self, def update_experience_maker(
new_actor_state_dict: Dict[str, Any] = None, self,
new_actor_lora_config_dict: Dict[str, Any] = None, new_actor_state_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None, new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None, new_critic_state_dict: Dict[str, Any] = None,
fully_update: bool = False, new_critic_lora_config_dict: Dict[str, Any] = None,
chunk_start: bool = None, fully_update: bool = False,
chunk_end: bool = None): chunk_start: bool = None,
''' chunk_end: bool = None,
called by trainer ):
chunk_start: Set True at the first call. Before sending state_dict calls """
chunk_end: Set True at the last call. After sending state_dict calls. called by trainer
fully_update: Set True if you want to sync models when initializing chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
TODO: load_state_dict integrate with model-sharding strategy fully_update: Set True if you want to sync models when initializing
'''
TODO: load_state_dict integrate with model-sharding strategy
"""
_watch_memory = self._debug _watch_memory = self._debug
if chunk_start: if chunk_start:
if self._debug: if self._debug:
...@@ -202,18 +201,22 @@ class ExperienceMakerHolder: ...@@ -202,18 +201,22 @@ class ExperienceMakerHolder:
else: else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device()) new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase( state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict) new_actor_state_dict, new_actor_lora_config_dict
)
self.actor_lora_constructor.load_state_dict_increase( self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase) self.experience_maker.actor.model, state_dict_increase
)
if new_critic_state_dict is not None: if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update: if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else: else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device()) new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase( state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict) new_critic_state_dict, new_critic_lora_config_dict
)
self.critic_lora_constructor.load_state_dict_increase( self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase) self.experience_maker.critic, state_dict_increase
)
# the lock must be released after both actor and critic being updated # the lock must be released after both actor and critic being updated
if chunk_end: if chunk_end:
...@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None: ...@@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model origin_model = actor.model
new_kwargs = {**generate_kwargs} new_kwargs = {**generate_kwargs}
# use huggingface models method directly # use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'): if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
return new_kwargs return new_kwargs
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional from typing import Any, Dict
import torch
import torch.nn as nn import torch.nn as nn
from coati.models.lora import LoraLinear from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer
@dataclass @dataclass
...@@ -17,7 +15,7 @@ class LoRAConfig: ...@@ -17,7 +15,7 @@ class LoRAConfig:
class LoRAConstructor: class LoRAConstructor:
''' """
Tools for reconstructing a model from a remote LoRA model. Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!) (Transferring only LoRA data costs much less!)
Usage: Usage:
...@@ -36,7 +34,7 @@ class LoRAConstructor: ...@@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver): Step 5 (Receiver):
load_state_dict_increase() load_state_dict_increase()
''' """
def __init__(self): def __init__(self):
self.lora_config_dict = None self.lora_config_dict = None
...@@ -45,10 +43,10 @@ class LoRAConstructor: ...@@ -45,10 +43,10 @@ class LoRAConstructor:
self.lora_config_dict = lora_config_dict self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]): def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
''' """
xxx.lora_A, xxx.lora_B -->> xxx.weight xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually. Warning: the xxx.weight here is the increment actually.
''' """
if lora_config_dict is not None: if lora_config_dict is not None:
self.register_lora_config(lora_config_dict) self.register_lora_config(lora_config_dict)
...@@ -56,24 +54,25 @@ class LoRAConstructor: ...@@ -56,24 +54,25 @@ class LoRAConstructor:
config_iter = iter(self.lora_config_dict.items()) config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items(): for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A': if k.rpartition(".")[-1] == "lora_A":
lora_A = v lora_A = v
layer_prefix = k.rpartition('.')[0] layer_prefix = k.rpartition(".")[0]
elif k.rpartition('.')[-1] == 'lora_B': elif k.rpartition(".")[-1] == "lora_B":
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter) layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config) weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None lora_A, lora_B, layer_prefix = None, None, None
else: else:
raise ValueError('unexpected key') raise ValueError("unexpected key")
return state_dict_increase return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()): def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w): def T(w):
return w.T if config.fan_in_fan_out else w return w.T if config.fan_in_fan_out else w
if config.r > 0: if config.r > 0:
scaling = config.lora_alpha / config.r scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling weight_data_increase = T(lora_B @ lora_A) * scaling
...@@ -81,21 +80,21 @@ class LoRAConstructor: ...@@ -81,21 +80,21 @@ class LoRAConstructor:
return 0 return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
    """
    The final reconstruction step: add every increment to the value the model
    currently stores under the same key, then load the sums back into the model.

    Args:
        model: module that is updated in place.
        state_dict_increase: maps state-dict keys to the increment to add.

    Raises:
        KeyError: if a key of ``state_dict_increase`` is not in ``model.state_dict()``.
    """
    # Build the state dict once; calling model.state_dict() inside the
    # comprehension would rebuild the whole mapping for every key.
    current_state = model.state_dict()
    updated_state = {key: increase + current_state[key] for key, increase in state_dict_increase.items()}
    # strict=False: only the incremented keys are supplied, not the full state.
    model.load_state_dict(updated_state, strict=False)
@staticmethod @staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
''' """
if keep_non_lora, also return non_lora state_dict if keep_non_lora, also return non_lora state_dict
''' """
state_dict_lora = OrderedDict() state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict() state_dict_non_lora = OrderedDict()
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k: if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v state_dict_lora[k] = v
elif keep_non_lora: elif keep_non_lora:
state_dict_non_lora[k] = v state_dict_non_lora[k] = v
...@@ -106,17 +105,19 @@ class LoRAConstructor: ...@@ -106,17 +105,19 @@ class LoRAConstructor:
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
    """
    Collect the LoRA settings of every LoraLinear submodule of ``model``.

    return OrderedDict(): name -> LoRAConfig, in ``model.named_modules()`` order.
    Modules that are not LoraLinear instances are ignored.
    """
    return OrderedDict(
        (
            name,
            LoRAConfig(
                r=child.r,
                lora_alpha=child.lora_alpha,
                lora_dropout=child.lora_dropout,
                fan_in_fan_out=child.fan_in_fan_out,
            ),
        )
        for name, child in model.named_modules()
        if isinstance(child, LoraLinear)
    )
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional from typing import Any, Dict
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic ...@@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool: def is_rank_0() -> bool:
...@@ -26,13 +26,13 @@ def get_world_size() -> int: ...@@ -26,13 +26,13 @@ def get_world_size() -> int:
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2': if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom': elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank) actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt': elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama': elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank) actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else: else:
raise ValueError(f'Unsupported actor model "{model}"') raise ValueError(f'Unsupported actor model "{model}"')
...@@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra ...@@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2': if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom': elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt': elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama': elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else: else:
raise ValueError(f'Unsupported reward model "{model}"') raise ValueError(f'Unsupported reward model "{model}"')
...@@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r ...@@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
def get_reward_model_from_args(model: str, pretrained: str = None, config=None): def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2': if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config) reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom': elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config) reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt': elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config) reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama': elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config) reward_model = LlamaRM(pretrained=pretrained, config=config)
else: else:
raise ValueError(f'Unsupported reward model "{model}"') raise ValueError(f'Unsupported reward model "{model}"')
...@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None): ...@@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
    """Build the training strategy selected by its command-line name.

    Args:
        strategy: one of "ddp", "colossalai_gemini", "colossalai_zero2",
            "colossalai_gemini_cpu" or "colossalai_zero2_cpu".

    Returns:
        A newly constructed strategy object for the given name.

    Raises:
        ValueError: if ``strategy`` is not one of the supported names.
    """
    if strategy == "ddp":
        return DDPStrategy()
    if strategy == "colossalai_gemini":
        return GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
    if strategy == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    if strategy == "colossalai_gemini_cpu":
        return GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
    if strategy == "colossalai_zero2_cpu":
        return LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    raise ValueError(f'Unsupported strategy "{strategy}"')
def get_tokenizer_from_args(model: str, **kwargs): def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2': if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == 'bloom': elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == 'opt': elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama': elif model == "llama":
pretrain_path = kwargs["pretrain"] pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path) tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else: else:
...@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs): ...@@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
def set_dist_env(env_info: Dict[str, str]):
    """Copy distributed-launch settings from ``env_info`` into ``os.environ``.

    Args:
        env_info: must contain the keys "rank", "local_rank", "world_size",
            "master_port" and "master_addr"; each value is written to the
            upper-cased environment variable of the same name.

    Raises:
        KeyError: if a required key is missing. Variables earlier in the
            order below have already been set when that happens.
    """
    variable_sources = (
        ("RANK", "rank"),
        ("LOCAL_RANK", "local_rank"),
        ("WORLD_SIZE", "world_size"),
        ("MASTER_PORT", "master_port"),
        ("MASTER_ADDR", "master_addr"),
    )
    for env_name, info_key in variable_sources:
        os.environ[env_name] = env_info[info_key]
def get_model_numel(model: nn.Module) -> int: def get_model_numel(model: nn.Module) -> int:
...@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i ...@@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers return target_receivers
def state_dict_to(state_dict: Dict[str, Any], def state_dict_to(
dtype: torch.dtype = torch.float16, state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
device: torch.device = torch.device('cpu')): ):
''' """
keep state_dict intact keep state_dict intact
''' """
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k, v in state_dict.items(): for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device) new_state_dict[k] = v.to(dtype=dtype, device=device)
......
...@@ -3,8 +3,4 @@ from .ppo import PPOTrainer ...@@ -3,8 +3,4 @@ from .ppo import PPOTrainer
from .rm import RewardModelTrainer from .rm import RewardModelTrainer
from .sft import SFTTrainer from .sft import SFTTrainer
__all__ = [ __all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
'SLTrainer', 'OnPolicyTrainer',
'RewardModelTrainer', 'SFTTrainer',
'PPOTrainer'
]
...@@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC): ...@@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process callbacks (List[Callback], defaults to []): the callbacks to call during training process
""" """
def __init__(self, def __init__(
strategy: Strategy, self,
data_buffer: NaiveExperienceBuffer, strategy: Strategy,
sample_buffer: bool, data_buffer: NaiveExperienceBuffer,
dataloader_pin_memory: bool, sample_buffer: bool,
callbacks: List[Callback] = []) -> None: dataloader_pin_memory: bool,
callbacks: List[Callback] = [],
) -> None:
super().__init__() super().__init__()
self.strategy = strategy self.strategy = strategy
self.data_buffer = data_buffer self.data_buffer = data_buffer
......
...@@ -2,4 +2,4 @@ from .base import Callback ...@@ -2,4 +2,4 @@ from .base import Callback
from .performance_evaluator import PerformanceEvaluator from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint'] __all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
...@@ -5,7 +5,7 @@ from coati.experience_maker import Experience ...@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class Callback(ABC): class Callback(ABC):
""" """
Base callback class. It defines the interface for callbacks. Base callback class. It defines the interface for callbacks.
""" """
def on_fit_start(self) -> None: def on_fit_start(self) -> None:
......
...@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None: ...@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
    """Divide ``x`` by ``y`` with two special-cased divisors.

    Returns:
        ``inf`` when ``y`` is 0 (instead of raising ZeroDivisionError),
        ``nan`` when ``y`` is positive infinity, otherwise ``x / y``.
    """
    if y == 0:
        return float("inf")
    if y == float("inf"):
        return float("nan")
    return x / y
...@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: ...@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer: class Timer:
def __init__(self) -> None: def __init__(self) -> None:
self.start_time: Optional[float] = None self.start_time: Optional[float] = None
self.duration: float = 0. self.duration: float = 0.0
def start(self) -> None: def start(self) -> None:
self.start_time = time() self.start_time = time()
...@@ -52,7 +51,7 @@ class Timer: ...@@ -52,7 +51,7 @@ class Timer:
self.start_time = None self.start_time = None
def reset(self) -> None: def reset(self) -> None:
self.duration = 0. self.duration = 0.0
class PerformanceEvaluator(Callback): class PerformanceEvaluator(Callback):
...@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback): ...@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance. ignore_episodes: The number of episodes to ignore when calculating the performance.
""" """
def __init__(self, def __init__(
actor_num_params: int, self,
critic_num_params: int, actor_num_params: int,
initial_model_num_params: int, critic_num_params: int,
reward_model_num_params: int, initial_model_num_params: int,
enable_grad_checkpoint: bool = False, reward_model_num_params: int,
ignore_episodes: int = 0) -> None: enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
) -> None:
super().__init__() super().__init__()
self.world_size = get_world_size() self.world_size = get_world_size()
self.actor_num_params = actor_num_params self.actor_num_params = actor_num_params
...@@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback): ...@@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size) avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size) avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
avg_make_experience_throughput = self.make_experience_num_samples * \ avg_make_experience_throughput = (
self.world_size / (avg_make_experience_duration + 1e-12) self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12) avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
...@@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback): ...@@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0( print_rank_0(
f'Performance summary:\n' f"Performance summary:\n"
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' + f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n' + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
) )
...@@ -36,34 +36,35 @@ class SaveCheckpoint(Callback): ...@@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):
""" """
def __init__(self, def __init__(
path: str, self,
interval: int, path: str,
strategy: Strategy, interval: int,
actor: nn.Module = None, strategy: Strategy,
critic: nn.Module = None, actor: nn.Module = None,
actor_optim: Optimizer = None, critic: nn.Module = None,
critic_optim: Optimizer = None) -> None: actor_optim: Optimizer = None,
critic_optim: Optimizer = None,
) -> None:
super().__init__() super().__init__()
self.path = os.path.join(path, 'checkpoint') self.path = os.path.join(path, "checkpoint")
self.interval = interval self.interval = interval
self.strategy = strategy self.strategy = strategy
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]} self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None: def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0: if (episode + 1) % self.interval != 0:
return return
base_path = os.path.join(self.path, f'episode_{episode}') base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path): if not os.path.exists(base_path):
os.makedirs(base_path) os.makedirs(base_path)
for model in self.model_dict.keys(): for model in self.model_dict.keys():
# save model # save model
if self.model_dict[model][0] is None: if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped # saving only optimizer states is meaningless, so it would be skipped
continue continue
model_path = os.path.join(base_path, f'{model}.pt') model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True) self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer # save optimizer
...@@ -71,5 +72,5 @@ class SaveCheckpoint(Callback): ...@@ -71,5 +72,5 @@ class SaveCheckpoint(Callback):
continue continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)) only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank() rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt') optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0) self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
...@@ -8,7 +8,7 @@ from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss ...@@ -8,7 +8,7 @@ from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs from coati.models.utils import calc_action_log_probs
from torch import Tensor from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
...@@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto ...@@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
hf_model = get_base_model(unwrapper_model) hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs} new_kwargs = {**generate_kwargs}
# use huggingface models method directly # use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'): if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'): if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs return new_kwargs
...@@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer): ...@@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating generate_kwargs (dict, optional): the kwargs to use while model generating
""" """
def __init__(self, def __init__(
strategy: Strategy, self,
actor: Actor, strategy: Strategy,
critic: Critic, actor: Actor,
reward_model: nn.Module, critic: Critic,
initial_model: Actor, reward_model: nn.Module,
actor_optim: Optimizer, initial_model: Actor,
critic_optim: Optimizer, actor_optim: Optimizer,
kl_coef: float = 0.1, critic_optim: Optimizer,
ptx_coef: float = 0.9, kl_coef: float = 0.1,
train_batch_size: int = 8, ptx_coef: float = 0.9,
buffer_limit: int = 0, train_batch_size: int = 8,
buffer_cpu_offload: bool = True, buffer_limit: int = 0,
eps_clip: float = 0.2, buffer_cpu_offload: bool = True,
vf_coef: float = 1.0, eps_clip: float = 0.2,
value_clip: float = 0.4, vf_coef: float = 1.0,
sample_buffer: bool = False, value_clip: float = 0.4,
dataloader_pin_memory: bool = True, sample_buffer: bool = False,
offload_inference_models: bool = True, dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [], offload_inference_models: bool = True,
**generate_kwargs callbacks: List[Callback] = [],
) -> None: **generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy): if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, \ assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
"GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__( super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
...@@ -130,18 +126,16 @@ class PPOTrainer(OnPolicyTrainer): ...@@ -130,18 +126,16 @@ class PPOTrainer(OnPolicyTrainer):
num_actions = experience.action_mask.size(1) num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs, actor_loss = self.actor_loss_fn(
experience.action_log_probs, action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
experience.advantages, )
action_mask=experience.action_mask)
# ptx loss # ptx loss
if self.ptx_coef != 0: if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next() batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device) batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch['input_ids'], ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
attention_mask=batch['attention_mask'])['logits'] ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.backward(actor_loss, self.actor, self.actor_optim)
...@@ -149,24 +143,23 @@ class PPOTrainer(OnPolicyTrainer): ...@@ -149,24 +143,23 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim.zero_grad() self.actor_optim.zero_grad()
# value loss # value loss
values = self.critic(experience.sequences, values = self.critic(
action_mask=experience.action_mask, experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
attention_mask=experience.attention_mask) )
critic_loss = self.critic_loss_fn(values, critic_loss = self.critic_loss_fn(
experience.values, values, experience.values, experience.reward, action_mask=experience.action_mask
experience.reward, )
action_mask=experience.action_mask)
critic_loss = critic_loss * self.vf_coef critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim) self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad() self.critic_optim.zero_grad()
return {'reward': experience.reward.mean().item()} return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int): def _learn(self, update_step: int):
if self.offload_inference_models: if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu') self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to('cpu') self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training # buffer may be empty at first, we should rebuild at each training
if self.sample_buffer: if self.sample_buffer:
...@@ -178,11 +171,7 @@ class PPOTrainer(OnPolicyTrainer): ...@@ -178,11 +171,7 @@ class PPOTrainer(OnPolicyTrainer):
else: else:
if isinstance(self.dataloader.sampler, DistributedSampler): if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step) self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm( pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
self.dataloader,
desc=f'Train epoch [{update_step + 1}]',
disable=not is_rank_0()
)
for experience in pbar: for experience in pbar:
self._on_learn_batch_start() self._on_learn_batch_start()
experience.to_device(self.device) experience.to_device(self.device)
......
...@@ -62,18 +62,15 @@ class RewardModelTrainer(SLTrainer): ...@@ -62,18 +62,15 @@ class RewardModelTrainer(SLTrainer):
if is_rank_0(): if is_rank_0():
log = pd.DataFrame( log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader), [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
self.loss.item(), self.dist, self.acc]], columns=["step", "loss", "dist", "acc"],
columns=['step', 'loss', 'dist', 'acc']
) )
log.to_csv('log.csv', mode='a', header=False, index=False) log.to_csv("log.csv", mode="a", header=False, index=False)
def _train(self, epoch): def _train(self, epoch):
self.model.train() self.model.train()
step_bar = tqdm.trange( step_bar = tqdm.trange(
len(self.train_dataloader), len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0()
) )
cnt = 0 cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
...@@ -93,10 +90,7 @@ class RewardModelTrainer(SLTrainer): ...@@ -93,10 +90,7 @@ class RewardModelTrainer(SLTrainer):
step_bar.update() step_bar.update()
step_bar.close() step_bar.close()
def _before_fit(self, def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader):
""" """
Args: Args:
train_dataloader (DataLoader): the dataloader to use for training train_dataloader (DataLoader): the dataloader to use for training
...@@ -104,7 +98,7 @@ class RewardModelTrainer(SLTrainer): ...@@ -104,7 +98,7 @@ class RewardModelTrainer(SLTrainer):
eval_dataloader (DataLoader): the dataloader to use for evaluation eval_dataloader (DataLoader): the dataloader to use for evaluation
""" """
super()._before_fit() super()._before_fit()
self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader self.valid_dataloader = valid_dataloader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment