Unverified commit c622bb36, authored by Frank Lee and committed by GitHub

Merge pull request #3915 from FrankLeeeee/update/develop

[sync] update develop with main
parents 34966378 9c88b6cb
from abc import ABC
from coati.experience_maker import Experience
class TrainerCallback(ABC):
"""
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
pass
def on_fit_end(self) -> None:
pass
def on_episode_start(self, episode: int) -> None:
pass
def on_episode_end(self, episode: int) -> None:
pass
def on_epoch_start(self, epoch: int) -> None:
pass
def on_epoch_end(self, epoch: int) -> None:
pass
def on_batch_start(self) -> None:
pass
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
pass
def on_update_start(self) -> None:
pass
def on_update_end(self) -> None:
pass
class MakerCallback(ABC):
def on_loop_start(self) -> None:
pass
def on_loop_end(self) -> None:
pass
def on_make_experience_start(self) -> None:
pass
def on_make_experience_end(self, experience: Experience) -> None:
pass
def on_send_start(self) -> None:
pass
def on_send_end(self) -> None:
pass
def on_batch_start(self) -> None:
pass
def on_batch_end(self) -> None:
pass
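TrainerCallback and MakerCallback above only define no-op hooks; concrete callbacks override whichever hooks they need. As a minimal sketch (this LoggingCallback class is illustrative only and not part of this commit):

class LoggingCallback(TrainerCallback):
    # Minimal illustrative callback: overrides two hooks, leaves the rest as no-ops.

    def on_episode_start(self, episode: int) -> None:
        print(f'[callback] episode {episode} started')

    def on_batch_end(self, metrics: dict, experience: Experience) -> None:
        # `metrics` is the dict returned by the trainer's training_step()
        print(f'[callback] batch finished, metrics: {metrics}')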
from time import time
from typing import Optional
import torch
import torch.distributed as dist
from coati.experience_maker import Experience
from .base import MakerCallback, TrainerCallback
def get_world_size() -> int:
if dist.is_initialized():
return dist.get_world_size()
return 1
def print_rank_0(*args, **kwargs) -> None:
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
def start(self) -> None:
self.start_time = time()
def end(self) -> None:
self.duration += time() - self.start_time
def reset(self) -> None:
self.duration = 0.
class ExperienceMakerPerformanceEvaluator(MakerCallback):
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
reward_model_num_params: int) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
self.critic_num_params = critic_num_params
self.initial_model_num_params = initial_model_num_params
self.reward_model_num_params = reward_model_num_params
self.batch_timer = Timer()
self.send_timer = Timer()
self.make_experience_timer = Timer()
self.total_samples: int = 0
self.make_experience_flop: int = 0
print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
)
def on_make_experience_start(self) -> None:
self.make_experience_timer.start()
def on_make_experience_end(self, experience: Experience) -> None:
self.make_experience_timer.end()
batch_size, seq_len = experience.sequences.shape
self.total_samples += batch_size
# actor generate
num_actions = experience.action_mask.size(1)
input_len = seq_len - num_actions
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
# actor forward
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
# critic forward
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
# initial model forward
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
# reward model forward
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
def on_send_start(self) -> None:
self.send_timer.start()
def on_send_end(self) -> None:
self.send_timer.end()
def on_batch_start(self) -> None:
self.batch_timer.start()
def on_batch_end(self) -> None:
self.batch_timer.end()
def on_loop_end(self) -> None:
avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
(self.total_samples * self.world_size)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' +
f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
class TrainerPerformanceEvaluator(TrainerCallback):
def __init__(self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
self.critic_num_params = critic_num_params
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_first_episodes = ignore_first_episodes
self.ignore_this_episode = False
self.episode_timer = Timer()
self.batch_timer = Timer()
self.update_timer = Timer()
self.total_samples: int = 0
self.learn_flop: int = 0
print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
)
def on_episode_start(self, episodes: int) -> None:
self.ignore_this_episode = episodes < self.ignore_first_episodes
if self.ignore_this_episode:
return
self.episode_timer.start()
def on_episode_end(self, episodes: int) -> None:
if self.ignore_this_episode:
return
self.episode_timer.end()
def on_batch_start(self) -> None:
if self.ignore_this_episode:
return
self.batch_timer.start()
def on_batch_end(self, metrics: dict, experience: Experience) -> None:
if self.ignore_this_episode:
return
self.batch_timer.end()
batch_size, seq_len = experience.sequences.shape
self.total_samples += batch_size
# actor forward-backward, 3 means forward(1) + backward(2)
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
# critic forward-backward
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
def on_update_start(self) -> None:
if self.ignore_this_episode:
return
self.update_timer.start()
def on_update_end(self) -> None:
if self.ignore_this_episode:
return
self.update_timer.end()
def on_fit_end(self) -> None:
if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation')
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)
avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' +
f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' +
f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+
f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
)
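The generation term in on_make_experience_end uses the usual ~2 FLOPs per parameter per processed token, summed over the growing context: during auto-regressive generation the actor sees contexts of length input_len, input_len + 1, ..., seq_len - 1, an arithmetic series whose sum is (input_len + seq_len - 1) * num_actions / 2. A standalone sketch of that estimate (hypothetical helper, same arithmetic as above):

def generation_flop_estimate(actor_num_params: int, batch_size: int, input_len: int, num_actions: int) -> float:
    # same arithmetic as ExperienceMakerPerformanceEvaluator.on_make_experience_end
    seq_len = input_len + num_actions
    # context lengths input_len, input_len + 1, ..., seq_len - 1 form an arithmetic series
    total_seq_len = (input_len + seq_len - 1) * num_actions / 2
    # ~2 FLOPs per parameter per token in a forward pass
    return 2 * actor_num_params * batch_size * total_seq_len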
import asyncio
import copy
import random
from threading import Lock
from typing import Any, List

import ray
import torch
from coati.experience_maker.base import Experience
from coati.replay_buffer import ReplayBuffer
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch

# from torch.multiprocessing import Queue
from ray.util.queue import Queue


class DetachedReplayBuffer:
    '''
    Detached replay buffer. Share Experience across workers on the same node.
    Therefore a trainer node is expected to have only one instance.
    It is ExperienceMakerHolder's duty to call append(exp) method, remotely.

    Args:
        sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch.
        tp_world_size: Number of workers in the same tp group
@@ -24,31 +26,25 @@ class DetachedReplayBuffer:
        cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
    '''

    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
        self.sample_batch_size = sample_batch_size
        self.limit = limit
        self.items = Queue(self.limit, actor_options={"num_cpus": 1})
        self.batch_collector: List[BufferItem] = []

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        '''
        Expected to be called remotely.
        '''
        items = split_experience_batch(experience)
        self.extend(items)

    @torch.no_grad()
    def extend(self, items: List[BufferItem]) -> None:
        '''
        Expected to be called remotely.
        '''
        self.batch_collector.extend(items)
        while len(self.batch_collector) >= self.sample_batch_size:
            items = self.batch_collector[:self.sample_batch_size]
@@ -62,19 +58,10 @@ class DetachedReplayBuffer:
        self.items = Queue(self.limit)
        self.worker_state = [False] * self.tp_world_size
        self.batch_collector = []

    @torch.no_grad()
    def sample(self, worker_rank=0, to_device="cpu") -> Experience:
        ret = self._sample_and_erase()
        ret.to_device(to_device)
        return ret
@@ -85,4 +72,4 @@ class DetachedReplayBuffer:
    def get_length(self) -> int:
        ret = self.items.qsize()
        return ret
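In the refactored buffer above, append() splits a batched Experience into BufferItems and delegates to extend(), which re-packs items into batches of sample_batch_size before enqueueing, so producers and consumers can use different batch sizes. A rough usage sketch (assuming ray.init() has already been called, since the underlying ray.util.queue.Queue is itself a Ray actor, and assuming `experience` is an Experience whose tensors have a leading batch dimension):

buffer = DetachedReplayBuffer(sample_batch_size=8, limit=16)
buffer.append(experience)    # split into items; full batches of 8 are enqueued
if buffer.get_length() > 0:
    batch = buffer.sample(to_device="cpu")    # pop one re-batched Experience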
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import ray
import torch
from coati.experience_maker import Experience
from coati.replay_buffer.utils import BufferItem
from torch.utils.data import DataLoader
from tqdm import tqdm

from .callbacks import TrainerCallback
from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0


class DetachedTrainer(ABC):
    '''
    Base class for detached rlhf trainers.
    'detach' means that the experience maker is detached compared to a normal Trainer.
    Please set name attribute during init:
        >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
@@ -19,87 +24,116 @@ class DetachedTrainer(ABC):
    Args:
        detached_strategy (DetachedStrategy): the strategy to use for training
        detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
        data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
        generate_kwargs (dict, optional): the kwargs to use while model generating
    '''

    def __init__(self,
                 experience_maker_holder_name_list: List[str],
                 train_batch_size: int = 8,
                 buffer_limit: int = 0,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[TrainerCallback] = [],
                 debug: bool = False) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        self._is_target_holder_initialized = False
        self._debug = debug

    def update_target_holder_list(self):
        # as the length of target_holder_list may be zero, we need to check it by a bool flag
        if not self._is_target_holder_initialized:
            for name in self.target_holder_name_list:
                self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
            self._is_target_holder_initialized = True

    @abstractmethod
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        pass

    def sync_models_to_remote_makers(self, **kwargs):
        self._update_remote_makers(fully_update=True, **kwargs)

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        pass

    def _learn(self, update_steps: int, train_epochs: int) -> None:
        data = []
        # warmup
        pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
        self._on_epoch_start(0)
        self._learn_epoch(pbar, data)
        self._on_epoch_end(0)
        # item is already a batch
        dataloader = DataLoader(data,
                                batch_size=1,
                                shuffle=True,
                                pin_memory=self.dataloader_pin_memory,
                                collate_fn=lambda x: x[0])
        for epoch in range(1, train_epochs):
            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
            self._on_epoch_start(epoch)
            self._learn_epoch(pbar, data)
            self._on_epoch_end(epoch)

    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
        is_warmup = len(data) == 0
        for x in pbar:
            if self._debug:
                print("[trainer] training step")
            # sample a batch and then train to avoid waiting
            experience = x if not is_warmup else self._buffer_sample()
            experience.to_device(torch.cuda.current_device())
            self._on_batch_start()
            metrics = self.training_step(experience)
            self._on_batch_end(metrics, experience)
            if self._debug:
                print("[trainer] step over")
            experience.to_device("cpu")
            if is_warmup:
                data.append(experience)
            pbar.set_postfix(metrics)

    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
        self._on_fit_start()
        for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
            self._on_episode_start(i)
            self._learn(update_steps, train_epochs)
            self._on_update_start()
            self._update_remote_makers()
            self._on_update_end()
            self._on_episode_end(i)
        self._on_fit_end()

    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer] telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        # called by ExperienceMakerHolder
        if self._debug:
            print(f"[trainer] receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_append")
    def buffer_extend(self, items: List[BufferItem]):
        # called by ExperienceMakerHolder
        if self._debug:
            print(f"[trainer] receiving exp.")
        self.detached_replay_buffer.extend(items)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        return self.detached_replay_buffer.sample()
@@ -119,3 +153,27 @@ class DetachedTrainer(ABC):
    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_start(epoch)

    def _on_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_end(epoch)

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_batch_end(metrics, experience)

    def _on_update_start(self) -> None:
        for callback in self.callbacks:
            callback.on_update_start()

    def _on_update_end(self) -> None:
        for callback in self.callbacks:
            callback.on_update_end()
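With the hooks above, the refactored trainer is driven by fit(total_steps, update_steps, train_epochs): it trains for update_steps batches, pushes fresh weights to the experience makers via _update_remote_makers(), and repeats total_steps // update_steps times. A rough call sketch, assuming trainer_ref is an already created DetachedPPOTrainer Ray actor; note that the demo scripts later in this commit still use the older episode-based fit signature:

ray.get(trainer_ref.sync_models_to_remote_makers.remote())    # full weight sync before training
done_ref = trainer_ref.fit.remote(total_steps=100,    # 100 training steps in total
                                  update_steps=10,    # push weights to makers every 10 steps
                                  train_epochs=1)     # one pass over each collected batch
ray.get(done_ref)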
from typing import Any, Callable, Dict, List, Optional, Tuple

import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
from torch.optim import Adam

from colossalai.nn.optimizer import HybridAdam

from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
    get_actor_from_args,
    get_critic_from_args,
    get_model_numel,
    get_rank,
    get_strategy_from_args,
    is_rank_0,
    set_dist_env,
    state_dict_to,
)


@ray.remote(concurrency_groups={
    "buffer_length": 1,
    "buffer_append": 1,
    "buffer_sample": 1,
    "model_io": 1,
    "compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
    '''
    Detached Trainer for PPO algorithm
@@ -40,86 +54,102 @@ class DetachedPPOTrainer(DetachedTrainer):
        generate_kwargs (dict, optional): the kwargs to use while model generating
    '''

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        model_fn: Callable[[], Tuple[Actor, Critic]],
        env_info: Dict[str, str] = None,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        eps_clip: float = 0.2,
        value_clip: float = 0.4,
        dataloader_pin_memory: bool = True,
        callbacks: List[TrainerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
    ) -> None:
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        # configure strategy
        self.strategy = strategy_fn()
        # configure models, loss and optimizers
        with self.strategy.model_init_context():
            self.actor, self.critic = model_fn()

        if eval_performance:
            actor_numel = get_model_numel(self.actor)
            critic_numel = get_model_numel(self.critic)
            evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
            callbacks = callbacks + [evaluator]

        if isinstance(self.strategy, ColossalAIStrategy):
            self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
        else:
            self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
            self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

        (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
            self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)

        super().__init__(experience_maker_holder_name_list,
                         train_batch_size=train_batch_size,
                         buffer_limit=buffer_limit,
                         dataloader_pin_memory=dataloader_pin_memory,
                         callbacks=callbacks,
                         debug=debug)
        if self._debug:
            print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')

        self._update_lora_weights = update_lora_weights

    @ray.method(concurrency_group="model_io")
    @torch.no_grad()
    def _update_remote_makers(self, fully_update: bool = False, **config):
        # TODO: balance duties
        if not fully_update:
            config['requires_grad_only'] = True
        self.update_target_holder_list()
        # mark start, ensure order
        tasks = []
        for target_holder in self.target_holder_list:
            tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
        ray.get(tasks)
        # sending loop
        tasks = []
        for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(
                        new_actor_state_dict=state_dict_shard,
                        new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
                        fully_update=fully_update))
        # sending loop
        for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(
                        new_critic_state_dict=state_dict_shard,
                        new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
                        fully_update=fully_update))
        ray.get(tasks)
        # mark end
        for target_holder in self.target_holder_list:
            target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

    @ray.method(concurrency_group="compute")
    def training_step(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()
        self.critic.train()

        num_actions = experience.action_mask.size(1)
        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self.actor_loss_fn(action_log_probs,
@@ -155,38 +185,16 @@ class DetachedPPOTrainer(DetachedTrainer):
    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

    def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
            if not self._update_lora_weights or fully_update:
                yield state_dict_to(state_dict)
            else:
                state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
                yield state_dict_to(state_dict_lora)

    def _get_model_lora_config_dict(self, model: torch.nn.Module):
        if not self._update_lora_weights:
            return None
        unwrapped_model = self.strategy.unwrap_model(model)
        return LoRAConstructor.extract_lora_config(unwrapped_model)
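The refactored DetachedPPOTrainer above takes zero-argument factories (strategy_fn, model_fn) instead of strategy and model name strings, so the strategy and models are constructed inside the Ray actor process. A rough construction sketch follows; the exact argument lists of get_strategy_from_args, get_actor_from_args and get_critic_from_args are assumed here for illustration:

# Illustrative only: the argument lists of the helper factories are assumed.
trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
    experience_maker_holder_name_list=["maker1"],
    strategy_fn=lambda: get_strategy_from_args("colossalai_zero2"),
    model_fn=lambda: (get_actor_from_args("opt", "facebook/opt-350m"),
                      get_critic_from_args("opt", "facebook/opt-350m")),
    train_batch_size=8,
    eval_performance=True,
)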
import argparse
from copy import deepcopy
import pandas as pd
import torch
from coati.trainer import PPOTrainer
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
import ray
import os
import socket
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainer = {'local_rank' : '0',
'rank' : '0',
'world_size' : '1',
'master_port' : trainer_port,
'master_addr' : master_addr}
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {'local_rank' : '0',
'rank' : '0',
'world_size' : '1',
'master_port' : maker_port,
'master_addr' : master_addr}
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure Trainer
trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1"],
strategy=args.trainer_strategy,
model=args.model,
env_info = env_info_trainer,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1"],
strategy=args.maker_strategy,
env_info = env_info_maker,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# the trainer sends its actor and critic to the experience holders.
ray.get(trainer_ref.initialize_remote_makers.remote())
# configure sampler
dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts):
# MUST pad to max length to ensure inputs of all ranks have the same length
# Different lengths may cause a hang when using gemini, due to differing numbers of generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.cuda() for k, v in batch.items()}
trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
ray.get([trainer_done_ref, maker_done_ref])
# save model checkpoint after fitting
trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('prompt_path')
parser.add_argument('--trainer_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--maker_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"])
main(args)
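For reference, num_exp_per_maker above budgets how many experience batches the single maker must produce for the trainer to finish its schedule; the extra 3 batches are slack for fault tolerance. With the argparse defaults (num_episodes=10, max_timesteps=10, update_timesteps=10, max_epochs=5) the budget works out as follows (hypothetical helper mirroring the in-line formula):

def exp_budget(num_episodes, max_timesteps, update_timesteps, max_epochs, slack=3):
    # mirrors: num_episodes * max_timesteps // update_timesteps * max_epochs + 3
    return num_episodes * max_timesteps // update_timesteps * max_epochs + slack

assert exp_budget(10, 10, 10, 5) == 53    # 10 * 10 // 10 * 5 + 3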
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
export RAY_NAMESPACE="admin"
python 1m1t.py "/path/to/prompts.csv" \
--trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
--max_epochs 10 --debug
import argparse
from copy import deepcopy
import pandas as pd
import torch
from coati.trainer import PPOTrainer
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
import ray
import os
import socket
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainer_1 = {'local_rank' : '0',
'rank' : '0',
'world_size' : '2',
'master_port' : trainer_port,
'master_addr' : master_addr}
env_info_trainer_2 = {'local_rank' : '0',
'rank' : '1',
'world_size' : '2',
'master_port' : trainer_port,
'master_addr' : master_addr}
# maker_env_info
maker_port = str(get_free_port())
env_info_maker_1 = {'local_rank' : '0',
'rank' : '0',
'world_size' : '2',
'master_port' : maker_port,
'master_addr' : master_addr}
print([env_info_trainer_1,
env_info_trainer_2,
env_info_maker_1])
ray.init(dashboard_port = 1145)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure Trainer
trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1"],
strategy=args.trainer_strategy,
model=args.model,
env_info=env_info_trainer_1,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1"],
strategy=args.trainer_strategy,
model=args.model,
env_info=env_info_trainer_2,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug= args.debug,
)
# configure Experience Maker
experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1", "trainer2"],
strategy=args.maker_strategy,
env_info=env_info_maker_1,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# the trainer sends its actor and critic to the experience holders.
# TODO: balance duty
ray.get(trainer_1_ref.initialize_remote_makers.remote())
# configure sampler
dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts):
# MUST pad to max length to ensure inputs of all ranks have the same length
# Different lengths may cause a hang when using gemini, due to differing numbers of generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.cuda() for k, v in batch.items()}
trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3 # +3 for fault tolerance
maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref])
# save model checkpoint after fitting
trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('prompt_path')
parser.add_argument('--trainer_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--maker_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
main(args)
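Note on the experience budget in this script: because two trainers both consume from the single maker, num_exp_per_maker multiplies the per-trainer schedule by 2 (plus 3 spare batches for fault tolerance), so the maker produces enough batches to keep both trainers running to completion.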
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
export RAY_NAMESPACE="admin"
python 1m2t.py "/path/to/prompts.csv" --model gpt2 \
--maker_strategy naive --trainer_strategy ddp --lora_rank 2 \
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
--max_epochs 10 #--debug
import argparse
from copy import deepcopy
import pandas as pd
import torch
from coati.trainer import PPOTrainer
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
import ray
import os
import socket
def main(args):
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure Trainer
trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1", "maker2"],
strategy=args.trainer_strategy,
model=args.model,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# configure Experience Maker
experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1"],
strategy=args.maker_strategy,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1"],
strategy=args.maker_strategy,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# the trainer sends its actor and critic to the experience holders.
ray.get(trainer_ref.initialize_remote_makers.remote())
# configure sampler
dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts):
# MUST pad to max length to ensure inputs of all ranks have the same length
# Different lengths may cause a hang when using gemini, due to differing numbers of generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.cuda() for k, v in batch.items()}
trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3 # +3 for fault tolerance
maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref])
# save model checkpoint after fitting
trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('prompt_path')
parser.add_argument('--trainer_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--maker_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"])
main(args)
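Note on the experience budget in this script: with two makers feeding a single trainer, the trainer's schedule is split between them, hence the // 2 in num_exp_per_maker (plus 3 spare batches per maker for fault tolerance).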
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 3
export RAY_NAMESPACE="admin"
python 2m1t.py "/path/to/prompts.csv" \
--trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
--max_epochs 10 # --debug
import argparse
from copy import deepcopy
import pandas as pd
import torch
from coati.trainer import PPOTrainer
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.experience_maker import NaiveExperienceMaker
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
import ray
import os
import socket
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainer_1 = {'local_rank' : '0',
'rank' : '0',
'world_size' : '2',
'master_port' : trainer_port,
'master_addr' : master_addr}
env_info_trainer_2 = {'local_rank' : '0',
'rank' : '1',
'world_size' : '2',
'master_port' : trainer_port,
'master_addr' : master_addr}
# maker_env_info
maker_port = str(get_free_port())
env_info_maker_1 = {'local_rank' : '0',
'rank' : '0',
'world_size' : '2',
'master_port' : maker_port,
'master_addr' : master_addr}
env_info_maker_2 = {'local_rank' : '0',
'rank' : '1',
'world_size' : '2',
'master_port': maker_port,
'master_addr' : master_addr}
print([env_info_trainer_1,
env_info_trainer_2,
env_info_maker_1,
env_info_maker_2])
ray.init()
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
else:
raise ValueError(f'Unsupported model "{args.model}"')
# configure Trainer
trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1", "maker2"],
strategy=args.trainer_strategy,
model=args.model,
env_info=env_info_trainer_1,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1", "maker2"],
strategy=args.trainer_strategy,
model=args.model,
env_info=env_info_trainer_2,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# configure Experience Maker
experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1", "trainer2"],
strategy=args.maker_strategy,
env_info=env_info_maker_1,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=["trainer1", "trainer2"],
strategy=args.maker_strategy,
env_info=env_info_maker_2,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
#kwargs:
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug=args.debug,
)
# the trainer sends its actor and critic to the experience maker holders.
# TODO: balance duty
ray.get(trainer_1_ref.initialize_remote_makers.remote())
# configure sampler
dataset = pd.read_csv(args.prompt_path)['prompt']
def tokenize_fn(texts):
# MUST pad to max length to ensure the inputs on all ranks have the same length
# Different lengths may cause a hang when using gemini, since ranks would take different numbers of generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.cuda() for k, v in batch.items()}
trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
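# rough estimate of how many experience batches each maker must produce:
# (num_episodes * max_timesteps / update_timesteps) updates, each consuming max_epochs batches,
# plus a small fault-tolerance margin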
num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref])
# save model checkpoint after fitting
trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('prompt_path')
parser.add_argument('--trainer_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--maker_strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
main(args)
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
export RAY_NAMESPACE="admin"
python 2m2t.py "path/to/prompts.csv" \
--maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \
--num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
--max_epochs 10 --debug
import os
import time
import tracemalloc
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import ray
import torch
import torch.nn as nn
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .utils import (get_model_numel,
get_rank,
get_world_size,
is_rank_0,
set_dist_env,
state_dict_to)
from .lora_constructor import LoRAConstructor
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
Args:
detached_trainer_name_list: list of ray actor names of the detached trainers to send experience to
strategy_fn: a function that returns the strategy used to prepare the models
model_fn: a function that returns (actor, critic, reward_model, initial_model)
kl_coef: the coefficient of the KL divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
'''
def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function that returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
self.target_trainer_list = []
assert len(detached_trainer_name_list) > 0
self._detached_trainer_name_list = detached_trainer_name_list
self.strategy = strategy_fn()
self.buffer_cpu_offload = buffer_cpu_offload
self.kl_coef = kl_coef
# init models
with self.strategy.model_init_context():
actor, critic, reward_model, initial_model = model_fn()
self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
if eval_performance:
actor_numel = get_model_numel(actor)
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
reward_model_numel)
callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
self.callbacks = callbacks
self._model_visit_lock = Lock()
self._is_fully_initialized = not sync_models_from_trainers
self._debug = debug
self._update_lora_weights = update_lora_weights
if self._update_lora_weights:
self.actor_lora_constructor = LoRAConstructor()
self.critic_lora_constructor = LoRAConstructor()
self.target_auto_balance = False
self._target_idx = 0
if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')
def _get_ready(self):
while not self._fully_initialized():
time.sleep(1.0)
def _fully_initialized(self):
return self._is_fully_initialized
def _init_target_trainer_list(self):
if len(self.target_trainer_list) > 0:
return
for name in self._detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
# copied from ../trainer/base.py
@ray.method(concurrency_group="compute")
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
@ray.method(concurrency_group="experience_io")
def _send_items(self, experience: Experience) -> None:
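# split the experience batch into per-sample items and distribute them to the target trainers in round-robin order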
self._init_target_trainer_list()
items = split_experience_batch(experience)
items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
for item in items:
items_per_trainer[self._target_idx].append(item)
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
for i, target_trainer in enumerate(self.target_trainer_list):
if len(items_per_trainer[i]) > 0:
target_trainer.buffer_extend.remote(items_per_trainer[i])
def _inference_step(self, batch) -> None:
self._on_batch_start()
with self._model_visit_lock:
self._on_make_experience_start()
experience = self._make_experience(batch)
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
experience.to_device('cpu')
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
"""Working loop of the experience maker.
Args:
dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
num_epochs (int, optional): Iterate the dataloader for this number of epochs. Defaults to 1.
num_steps (int, optional): Iterate the dataloader for this number of steps. If this value > 0, num_epochs is ignored. Defaults to 0.
"""
self._get_ready()
self._on_loop_start()
dataloader = dataloader_fn()
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
it = iter(dataloader)
batch = next(it)
self._inference_step(batch)
else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
pbar.update()
self._on_loop_end()
@ray.method(concurrency_group="model_io")
def update_experience_maker(self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None):
'''
Called by the trainer.
chunk_start: set True on the first call, before any state_dict chunk is sent.
chunk_end: set True on the last call, after all state_dict chunks have been sent.
fully_update: set True to sync the full models (e.g. when initializing).
TODO: integrate load_state_dict with the model-sharding strategy
'''
_watch_memory = self._debug
if chunk_start:
if self._debug:
print("[maker] UPDATE ")
if _watch_memory:
tracemalloc.start()
self._model_visit_lock.acquire()
with torch.no_grad():
if new_actor_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
# the lock must be released only after both the actor and the critic have been updated
if chunk_end:
self._model_visit_lock.release()
if _watch_memory:
current, peak = tracemalloc.get_traced_memory()
print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
tracemalloc.stop()
if fully_update:
self._is_fully_initialized = True
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
callback.on_make_experience_start()
def _on_make_experience_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_make_experience_end(experience)
def _on_loop_start(self) -> None:
for callback in self.callbacks:
callback.on_loop_start()
def _on_loop_end(self) -> None:
for callback in self.callbacks:
callback.on_loop_end()
def _on_send_start(self) -> None:
for callback in self.callbacks:
callback.on_send_start()
def _on_send_end(self) -> None:
for callback in self.callbacks:
callback.on_send_end()
def _on_batch_start(self) -> None:
for callback in self.callbacks:
callback.on_batch_start()
def _on_batch_end(self) -> None:
for callback in self.callbacks:
callback.on_batch_end()
def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> dict:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
return new_kwargs
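# Hypothetical launch sketch (illustrative, not part of the original file): it shows how an
# ExperienceMakerHolder could be created as a named ray actor, assuming a running ray cluster,
# a RAY_NAMESPACE environment variable, and an existing detached trainer actor named "trainer1"
# in the same namespace. The helper functions are taken from this package's `utils` module.
if __name__ == '__main__':
    from .utils import get_actor_from_args, get_critic_from_args, get_reward_model_from_args, get_strategy_from_args

    def example_model_fn():
        # build (actor, critic, reward_model, initial_model) for a small gpt2 setup
        actor = get_actor_from_args('gpt2')
        critic = get_critic_from_args('gpt2')
        reward_model = get_reward_model_from_args('gpt2')
        initial_model = get_actor_from_args('gpt2')
        return actor, critic, reward_model, initial_model

    maker_ref = ExperienceMakerHolder.options(name='maker1',
                                              namespace=os.environ['RAY_NAMESPACE'],
                                              num_gpus=1,
                                              max_concurrency=2).remote(
        detached_trainer_name_list=['trainer1'],
        strategy_fn=lambda: get_strategy_from_args('naive'),
        model_fn=example_model_fn,
        eval_performance=True)
    # `example_dataloader_fn` would be any zero-argument callable returning an iterable of
    # tokenized prompt batches; workingloop then makes and sends experience, e.g.:
    # ray.get(maker_ref.workingloop.remote(example_dataloader_fn, num_steps=100))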
from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
from dataclasses import dataclass
import torch
import torch.nn as nn
from loralib.layers import LoRALayer
from coati.models.lora import LoraLinear
@dataclass
class LoRAConfig:
r: int = 0
lora_alpha: int = 1
lora_dropout: float = 0
fan_in_fan_out: bool = False
class LoRAConstructor:
'''
Tools for reconstructing a model from a remote LoRA model.
(Transferring only the LoRA data costs much less!)
Usage:
Step 1 (Sender):
filter_state_dict_lora()
Step 2 (Sender, Optional):
extract_lora_config()
Step 3 (Sender):
send state_dict_lora and lora_config_dict
Step 4 (Receiver):
reconstruct_increase()
Step 5 (Receiver):
load_state_dict_increase()
(an illustrative end-to-end sketch follows this class definition)
'''
def __init__(self):
self.lora_config_dict = None
def register_lora_config(self, lora_config_dict: Dict[str, Any]):
self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
'''
xxx.lora_A, xxx.lora_B -> xxx.weight
Warning: the resulting xxx.weight is actually the weight increment, not the full weight.
'''
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)
state_dict_increase = OrderedDict()
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A':
lora_A = v
layer_prefix = k.rpartition('.')[0]
elif k.rpartition('.')[-1] == 'lora_B':
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
raise ValueError(f'unexpected key: {k}')
return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
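# the weight increment is (lora_alpha / r) * (lora_B @ lora_A), transposed when fan_in_fan_out is set;
# r == 0 means LoRA is disabled, so the increment is zero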
def T(w):
return w.T if config.fan_in_fan_out else w
if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
return weight_data_increase
return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
'''
The final reconstruction step
'''
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
'''
if keep_non_lora, also return non_lora state_dict
'''
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
if keep_non_lora:
return state_dict_lora, state_dict_non_lora
else:
return state_dict_lora, None
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
'''
Extract the LoRA config from every LoraLinear module.
Returns an OrderedDict mapping module name -> LoRAConfig.
'''
lora_config_dict = OrderedDict()
for name, child in model.named_modules():
if isinstance(child, LoraLinear):
lora_config_dict[name] = LoRAConfig(r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out)
return lora_config_dict
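# Illustrative round-trip sketch (not part of the original file), following the Usage steps in
# the class docstring. `trainer_actor` and `maker_actor` stand in for the trainer-side and
# maker-side copies of a LoRA-enabled model; a GPT-2 actor with a small LoRA rank is used here
# purely as an example.
if __name__ == '__main__':
    from coati.models.gpt import GPTActor

    trainer_actor = GPTActor(lora_rank=4)
    maker_actor = GPTActor(lora_rank=4)

    # sender (trainer) side: keep only the LoRA tensors and their configs
    state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(trainer_actor.model.state_dict())
    lora_config_dict = LoRAConstructor.extract_lora_config(trainer_actor.model)

    # receiver (experience maker) side: normally the two dicts arrive via a ray call
    constructor = LoRAConstructor()
    state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config_dict)
    constructor.load_state_dict_increase(maker_actor.model, state_dict_increase)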
import torch
from typing import Any, Callable, Dict, List, Optional, Union
import ray
from ray.exceptions import GetTimeoutError
from torch import Tensor
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies.sampler import DistributedSampler
from coati.trainer.strategies import Strategy
from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker
from copy import deepcopy
from threading import Lock
import time
import os
from .utils import is_rank_0, get_strategy_from_args, set_dist_env
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
Args:
detached_trainer_name_list: str list to get ray actor handles
strategy:
experience_batch_size: batch size of generated experience
kl_coef: the coefficient of kl divergence loss
'''
def __init__(self,
detached_trainer_name_list: List[str],
strategy: str,
env_info: Dict[str, str] = None,
experience_batch_size: int = 8,
kl_coef: float = 0.1,
**generate_kwargs):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
self.target_trainer_list = []
for name in detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
self.strategy_str = strategy
self.strategy = get_strategy_from_args(strategy)
self.experience_batch_size = experience_batch_size
self.kl_coef = kl_coef
self.generate_kwargs = generate_kwargs
# Need a trainer to give an actor and a critic via initialize_experience_maker(...)
actor, critic, reward_model, initial_model = None, None, None, None
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
self._model_visit_lock = Lock()
self.fully_initialized = False
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print('[maker] Waiting for INIT')
def _get_ready(self):
while not self.fully_initialized:
time.sleep(1.0)
def update_target_trainer_list(self, detached_trainer_name_list):
self.target_trainer_list = []
for name in detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name))
# copied from ../trainer/base.py
@ray.method(concurrency_group="compute")
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
self._get_ready()
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
@ray.method(concurrency_group="experience_io")
def _send_experience(self, experience):
'''
The load-balancing logic below is intentionally disabled and kept for reference:
# choose a trainer that has the least experience batch in its detached_replay_buffer
chosen_trainer = None
min_length = None
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print("[maker] choosing target trainer")
while chosen_trainer is None:
for target_trainer in self.target_trainer_list:
try:
temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
if min_length is None:
min_length = temp_length
chosen_trainer = target_trainer
else:
if temp_length < min_length:
min_length = temp_length
chosen_trainer = target_trainer
except GetTimeoutError:
pass
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print(f"[maker] sending exp to {chosen_trainer}")
chosen_trainer.buffer_append.remote(experience)
'''
#
if not hasattr(self, "_target_idx"):
self._target_idx = 0
chosen_trainer = self.target_trainer_list[self._target_idx]
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print(f"[maker] sending exp to {chosen_trainer}")
chosen_trainer.buffer_append.remote(experience)
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
self._get_ready()
sampler = self.strategy.setup_sampler(dataset)
for _ in range(times):
rand_prompts = sampler.sample(self.experience_batch_size)
if tokenizer is not None:
inputs = tokenizer(rand_prompts)
else:
inputs = rand_prompts
self._model_visit_lock.acquire()
experience = self._make_experience(inputs=inputs)
self._model_visit_lock.release()
self._send_experience(experience=experience)
@ray.method(concurrency_group="model_io")
def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
'''
called by trainer. Only once.
'''
# TODO: reduce malloc
if self.fully_initialized:
return
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print('[maker] INIT')
with torch.no_grad():
with self.strategy.model_init_context():
actor = init_actor
critic = init_critic
initial_model = deepcopy(actor)
reward_model = RewardModel(deepcopy(critic.model),
deepcopy(critic.value_head)).to(torch.cuda.current_device())
if self.strategy_str != 'colossalai_gemini':
actor.to(torch.float16).to(torch.cuda.current_device())
critic.to(torch.float16).to(torch.cuda.current_device())
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
self.experience_maker.actor = self.strategy.prepare(actor)
self.experience_maker.critic = self.strategy.prepare(critic)
self.experience_maker.initial_model = self.strategy.prepare(initial_model)
self.experience_maker.reward_model = self.strategy.prepare(reward_model)
self.fully_initialized = True
@ray.method(concurrency_group="model_io")
def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
'''
called by trainer
'''
# TODO: reduce malloc
self._model_visit_lock.acquire()
with torch.no_grad():
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print("[maker] UPDATE ")
if self.strategy_str != 'colossalai_gemini':
new_actor.to(torch.float16).to(torch.cuda.current_device())
new_critic.to(torch.float16).to(torch.cuda.current_device())
self.experience_maker.actor = self.strategy.prepare(new_actor)
self.experience_maker.critic = self.strategy.prepare(new_critic)
self._model_visit_lock.release()
# WIP
from coati.trainer.strategies import Strategy
from coati.trainer.strategies import NaiveStrategy
from coati.models.base import Actor, RewardModel, Critic
import numpy as np
import torch
from torch._C._distributed_rpc import _is_current_rpc_agent_set
import colossalai
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.pipeline.middleware.adaptor import get_fx_topology
import os
from functools import partial
import random
rpc_is_initialized = _is_current_rpc_agent_set
class PipelineModel(torch.nn.Module):
'''
The Actor has two kinds of jobs: forward and generate,
so it is better to pipeline only the inner model.
'''
def __init__(self,
model: torch.nn.Module,
stage_num: int,
num_microbatches: int,
data_kwargs = None,
):
super().__init__()
# create partition module
def create_partition_module(pp_rank:int, stage_num: int, model, data_kwargs):
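# trace the model into an FX graph with ColoTracer, split it into `stage_num` balanced stages,
# attach the pipeline topology to every stage, and return the submodule assigned to this pipeline rank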
model.eval()
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated_model = balanced_split_pass(gm, stage_num)
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
topo = get_fx_topology(top_module)
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo)
return split_submodules[pp_rank + 1]
def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
return partition
self.inference_engine = OneFOneBPipelineEngine(
partition_fn=partial(partition, model, data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device='cuda',
)
def forward(self,
**model_inputs):
return self.inference_engine.forward_backward(**model_inputs, forward_only=True)
class PPStrategy(NaiveStrategy):
"""
Strategy for pipeline inference (inference only!)
runs on the master node only
"""
def __init__(
self,
seed: int = 42
):
self.seed = seed
super().__init__()
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
ppg.set_global_info(rank = int(os.environ['RANK']),
world_size=int(os.environ['WORLD_SIZE']),
dp_degree=1,
tp_degree=1,
num_worker_threads=128,
device="cuda")
def model_init_context(self):
return super().model_init_context()
def setup_model(self, model: torch.nn.Module) -> torch.nn.Module:
if isinstance(model, Actor) or \
isinstance(model, RewardModel) or \
isinstance(model, Critic):
model.model = PipelineModel(model.model)
def set_seed(self, seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
import torch
import os
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0):
if model == 'gpt2':
actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
elif model == 'bloom':
actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
elif model == 'opt':
actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{model}"')
return actor, critic
def get_strategy_from_args(strategy: str):
if strategy == 'naive':
strategy_ = NaiveStrategy()
elif strategy == 'ddp':
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
import os
from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
def get_rank() -> int:
return dist.get_rank() if dist.is_initialized() else 0
def get_world_size() -> int:
return dist.get_world_size() if dist.is_initialized() else 1
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom':
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt':
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama':
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'roberta':
actor = RoBERTaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
return actor
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom':
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt':
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama':
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'roberta':
critic = RoBERTaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported critic model "{model}"')
return critic
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2':
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom':
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt':
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama':
reward_model = LlamaRM(pretrained=pretrained, config=config)
elif model == 'roberta':
reward_model = RoBERTaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
return reward_model
def get_strategy_from_args(strategy: str):
if strategy == 'naive':
strategy_ = NaiveStrategy()
elif strategy == 'ddp':
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
elif model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama':
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
elif model == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
else:
raise ValueError(f'Unsupported model "{model}"')
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
def get_model_numel(model: nn.Module) -> int:
numel = sum(p.numel() for p in model.parameters())
return numel
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
target_receivers = []
if num_senders <= num_receivers or allow_idle_sender:
# a sender will send data to one or more receivers
# each receiver has exactly one sender
for i in range(num_receivers):
if i % num_senders == sender_idx:
target_receivers.append(i)
else:
# a sender will send data to one receiver
# a receiver may have more than one sender
target_receivers.append(sender_idx % num_receivers)
return target_receivers
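# illustrative examples of the mapping above:
# with 2 senders and 4 receivers, sender 0 feeds receivers [0, 2] and sender 1 feeds [1, 3];
# with 4 senders, 2 receivers and allow_idle_sender=False, senders 0..3 feed [0], [1], [0], [1]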
def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
Return a converted copy, keeping the original state_dict intact.
'''
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
return new_state_dict