"tests/test_trainer/vscode:/vscode.git/clone" did not exist on "45355a62f7a398990683c02908136abcddb82ad7"
Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
import os
from collections import OrderedDict
from typing import Any, Dict
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
def get_rank() -> int:
return dist.get_rank() if dist.is_initialized() else 0
def get_world_size() -> int:
return dist.get_world_size() if dist.is_initialized() else 1
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
return actor
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
return critic
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
return reward_model
def get_strategy_from_args(strategy: str):
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
raise ValueError(f'Unsupported model "{model}"')
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info["rank"]
os.environ["LOCAL_RANK"] = env_info["local_rank"]
os.environ["WORLD_SIZE"] = env_info["world_size"]
os.environ["MASTER_PORT"] = env_info["master_port"]
os.environ["MASTER_ADDR"] = env_info["master_addr"]
def get_model_numel(model: nn.Module) -> int:
numel = sum(p.numel() for p in model.parameters())
return numel
def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
target_receivers = []
if num_senders <= num_receivers or allow_idle_sender:
# a sender will send data to one or more receivers
# a receiver only has one sender
for i in range(num_receivers):
if i % num_senders == sender_idx:
target_receivers.append(i)
else:
# a sender will send data to one receiver
# a receiver may have more than one sender
target_receivers.append(sender_idx % num_receivers)
return target_receivers
def state_dict_to(
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
"""
keep state_dict intact
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
return new_state_dict
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
from .base import Trainer
from .base import OnPolicyTrainer, SLTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from contextlib import contextmanager
from typing import List
import torch
import torch.nn as nn
import tqdm
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
from torch.optim import Optimizer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
class Trainer(ABC):
class SLTrainer(ABC):
"""
Base class for rlhf trainers.
Base class for supervised learning trainers.
Args:
strategy (Strategy):the strategy to use for training
max_epochs (int, defaults to 1): the number of epochs of training process
model (nn.Module): the model to train
optim (Optimizer): the optimizer to use for training
"""
def __init__(
self,
strategy: Strategy,
max_epochs: int,
model: nn.Module,
optimizer: Optimizer,
) -> None:
super().__init__()
self.strategy = strategy
self.max_epochs = max_epochs
self.model = model
self.optimizer = optimizer
@abstractmethod
def _train(self, epoch):
raise NotImplementedError()
@abstractmethod
def _eval(self, epoch):
raise NotImplementedError()
def _before_fit(self):
raise NotImplementedError()
def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
self._train(epoch)
self._eval(epoch)
class OnPolicyTrainer(ABC):
"""
Base class for on-policy rl trainers, e.g. PPO.
Args:
strategy (Strategy):the strategy to use for training
data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
max_epochs: int = 1,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
def __init__(
self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = [],
) -> None:
super().__init__()
self.strategy = strategy
self.max_epochs = max_epochs
self.generate_kwargs = generate_kwargs
self.data_buffer = data_buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
# TODO(ver217): maybe simplify these code using context
def _on_fit_start(self) -> None:
@contextmanager
def _fit_ctx(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()
def _on_fit_end(self) -> None:
for callback in self.callbacks:
callback.on_fit_end()
def _on_episode_start(self, episode: int) -> None:
try:
yield
finally:
for callback in self.callbacks:
callback.on_fit_end()
@contextmanager
def _episode_ctx(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)
def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
try:
yield
finally:
for callback in self.callbacks:
callback.on_episode_end(episode)
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
......@@ -70,6 +122,67 @@ class Trainer(ABC):
for callback in self.callbacks:
callback.on_learn_batch_start()
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def _on_learn_batch_end(self, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)
callback.on_learn_batch_end(experience)
@abstractmethod
def _make_experience(self, collect_step: int):
"""
Implement this method to make experience.
"""
raise NotImplementedError()
@abstractmethod
def _learn(self, update_step: int):
"""
Implement this method to learn from experience, either
sample from buffer or transform buffer into dataloader.
"""
raise NotImplementedError()
def _collect_phase(self, collect_step: int):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
self._on_make_experience_end(experience)
self.data_buffer.append(experience)
def _update_phase(self, update_step: int):
self._on_learn_epoch_start(update_step)
self._learn(update_step)
self._on_learn_epoch_end(update_step)
def _before_fit(self, *args, **kwargs):
raise NotImplementedError()
def fit(
self,
num_episodes: int,
num_collect_steps: int,
num_update_steps: int,
*args,
**kwargs,
):
"""
The main training loop of on-policy rl trainers.
Args:
num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode
"""
self._before_fit(*args, **kwargs)
with self._fit_ctx():
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
with self._episode_ctx(episode):
for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
self._collect_phase(collect_step)
if not self.sample_buffer:
# HACK(cwher): according to the design of boost API, dataloader should also be boosted,
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
# I only call strategy.setup_dataloader() to setup dataloader.
self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.data_buffer.clear()
......@@ -2,4 +2,4 @@ from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
......@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class Callback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
......@@ -35,5 +35,5 @@ class Callback(ABC):
def on_learn_batch_start(self) -> None:
pass
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def on_learn_batch_end(self, experience: Experience) -> None:
pass
......@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
if y == 0:
return float('inf')
elif y == float('inf'):
return float('nan')
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
......@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
......@@ -52,7 +51,7 @@ class Timer:
self.start_time = None
def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
class PerformanceEvaluator(Callback):
......@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
......@@ -136,7 +137,7 @@ class PerformanceEvaluator(Callback):
return
self.learn_timer.start()
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
def on_learn_batch_end(self, experience: Experience) -> None:
if self.disable:
return
self.learn_timer.end()
......@@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
avg_make_experience_throughput = self.make_experience_num_samples * \
self.world_size / (avg_make_experience_duration + 1e-12)
avg_make_experience_throughput = (
self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
......@@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
f'Performance summary:\n' +
f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+
f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+
f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
f"Performance summary:\n"
+ f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
import os
import torch.distributed as dist
from coati.trainer.strategies import ColossalAIStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer
......@@ -36,40 +36,41 @@ class SaveCheckpoint(Callback):
"""
def __init__(self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None) -> None:
def __init__(
self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None,
) -> None:
super().__init__()
self.path = os.path.join(path, 'checkpoint')
self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
base_path = os.path.join(self.path, f'episode_{episode}')
base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped
continue
model_path = os.path.join(base_path, f'{model}.pt')
model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
if self.model_dict[model][1] is None:
continue
only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from coati.models.utils import calc_action_log_probs
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
from .base import Trainer
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0, to_device
from .strategies import GeminiStrategy, Strategy
from .utils import CycledDataLoader, is_rank_0, to_device
class PPOTrainer(Trainer):
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs
class PPOTrainer(OnPolicyTrainer):
"""
Trainer for PPO algorithm.
......@@ -28,60 +40,61 @@ class PPOTrainer(Trainer):
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
buffer_limit (int, defaults to 0): the max_size limitation of buffer
buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
max_epochs (int, defaults to 1): the number of epochs of training process
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
max_epochs: int = 1,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.sample_replay_buffer = sample_replay_buffer
self.offload_inference_models = offload_inference_models
def __init__(
self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
self.actor = actor
self.critic = critic
self.tokenizer = tokenizer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
......@@ -91,123 +104,99 @@ class PPOTrainer(Trainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer:
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
experience = self.replay_buffer.sample()
experience.to_device(self.device)
metrics = self.training_step(experience)
pbar.set_postfix(metrics)
else:
for epoch in range(self.max_epochs):
self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
metrics = self.training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._on_learn_epoch_end(epoch)
def fit(self,
prompt_dataloader,
pretrain_dataloader,
num_episodes: int = 50000,
max_timesteps: int = 500,
update_timesteps: int = 5000) -> None:
time = 0
self.pretrain_dataloader = pretrain_dataloader
self.prompt_dataloader = prompt_dataloader
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
prompts = next(iter(self.prompt_dataloader))
self._on_make_experience_start()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
experience = self._make_experience(prompts)
self._on_make_experience_end(experience)
self.replay_buffer.append(experience)
if time % update_timesteps == 0:
if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self._learn()
self.replay_buffer.clear()
self._on_episode_end(episode)
self._on_fit_end()
def training_step(self, experience: Experience) -> Dict[str, float]:
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-ppo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "ppo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _make_experience(self, collect_step: int) -> Experience:
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
def _training_step(self, experience: Experience):
self.actor.train()
self.critic.train()
# policy loss
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
num_actions = experience.action_log_probs.size(1)
actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
actor_loss = (1 - self.ptx_coef) * actor_loss
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = next(iter(self.pretrain_dataloader))
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
# value loss
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'reward': experience.reward.mean().item()}
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy.unwrap_model(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
return new_kwargs
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
from datetime import datetime
from typing import List, Optional
from typing import Callable, Optional
import pandas as pd
import torch
import torch.distributed as dist
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Trainer
from .callbacks import Callback
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0
class RewardModelTrainer(Trainer):
class RewardModelTrainer(SLTrainer):
"""
Trainer to use while training reward model.
Args:
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
optim (Optimizer): the optimizer to use for training
lr_scheduler (_LRScheduler): the lr scheduler to use for training
loss_fn (callable): the loss function to use for training
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
def __init__(
......@@ -37,87 +29,95 @@ class RewardModelTrainer(Trainer):
model,
strategy: Strategy,
optim: Optimizer,
loss_fn,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader,
lr_scheduler: _LRScheduler,
loss_fn: Callable,
max_epochs: int = 1,
callbacks: List[Callback] = [],
) -> None:
super().__init__(strategy, max_epochs, callbacks=callbacks)
super().__init__(strategy, max_epochs, model, optim)
self.loss_fn = loss_fn
self.scheduler = lr_scheduler
self.num_train_step = 0
def _eval(self, epoch):
if self.eval_dataloader is not None:
self.model.eval()
dist, num_correct, num_samples = 0, 0, 0
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
num_samples += chosen_ids.size(0)
num_correct += (chosen_reward > reject_reward).sum().item()
dist += (chosen_reward - reject_reward).mean().item()
self.dist = dist / len(self.eval_dataloader)
self.acc = num_correct / num_samples
if self.writer:
self.writer.add_scalar("eval/dist", self.dist, epoch)
self.writer.add_scalar("eval/acc", self.acc, epoch)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
if self.writer:
self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
self.writer.add_scalar(
"train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
)
self.num_train_step += 1
if self.num_train_step % 100 == 0:
self.scheduler.step()
step_bar.update()
step_bar.close()
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
self.model = model
self.loss_fn = loss_fn
self.optimizer = optim
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
def eval_acc(self, dataloader):
dist = 0
on = 0
cnt = 0
self.model.eval()
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
dist += (chosen_reward - reject_reward).mean().item()
dist_mean = dist / len(dataloader)
acc = on / cnt
self.model.train()
return dist_mean, acc
def fit(self):
time = datetime.now()
epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.max_epochs):
step_bar = tqdm(range(self.train_dataloader.__len__()),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0())
# train
self.model.train()
cnt = 0
acc = 0
dist = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt == 100:
self.scheduler.step()
dist, acc = self.eval_acc(self.valid_dataloader)
cnt = 0
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
step_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc})
# eval
dist, acc = self.eval_acc(self.eval_dataloader)
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log.csv', mode='a', header=False, index=False)
epoch_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc})
step_bar.close()
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
import math
import time
from typing import List, Optional
from typing import Optional
import torch
import torch.distributed as dist
import wandb
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler
from .base import Trainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from colossalai.logging import DistributedLogger
from .base import SLTrainer
from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
class SFTTrainer(Trainer):
class SFTTrainer(SLTrainer):
"""
Trainer to use while training reward model.
......@@ -25,12 +22,9 @@ class SFTTrainer(Trainer):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
lr_scheduler(_LRScheduler): the lr scheduler to use for training
max_epochs (int, defaults to 2): the number of epochs to train
callbacks (List[Callback], defaults to []): the callbacks to call during training process
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
"""
def __init__(
......@@ -38,98 +32,99 @@ class SFTTrainer(Trainer):
model,
strategy: Strategy,
optim: Optimizer,
train_dataloader: DataLoader,
eval_dataloader: DataLoader = None,
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
callbacks: List[Callback] = [],
) -> None:
if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
super().__init__(strategy, max_epochs, callbacks=callbacks)
if accumulation_steps > 1:
assert not isinstance(
strategy, GeminiStrategy
), "Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim)
self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler
self.num_train_step = 0
self.num_eval_step = 0
def _train(self, epoch: int):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss / self.accumulation_steps
self.total_loss += loss.item()
self.strategy.backward(loss, self.model, self.optimizer)
# gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if self.writer:
self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.total_loss = 0
step_bar.update()
step_bar.close()
def _eval(self, epoch: int):
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
)
loss_sum += outputs.loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
if self.writer:
self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
self.num_eval_step += 1
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader: the dataloader to use for training
eval_dataloader: the dataloader to use for evaluation
"""
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.model = model
self.optimizer = optim
self.accumulation_steps = accumulation_steps
num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
self.scheduler = get_scheduler("cosine",
self.optimizer,
num_warmup_steps=math.ceil(max_steps * 0.03),
num_training_steps=max_steps)
def fit(self, logger, use_wandb: bool = False):
if use_wandb:
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
wandb.watch(self.model)
total_loss = 0
# epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
desc=f'steps',
disable=not is_rank_0())
for epoch in range(self.max_epochs):
# process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
# train
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
if loss >= 2.5 and is_rank_0():
logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
loss = loss / self.accumulation_steps
self.strategy.backward(loss, self.model, self.optimizer)
total_loss += loss.item()
# gradient accumulation
if (batch_id + 1) % self.accumulation_steps == 0:
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and use_wandb:
wandb.log({
"loss": total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
total_loss = 0
step_bar.update()
# if batch_id % log_interval == 0:
# logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
# wandb.log({"loss": loss.item()})
# process_bar.update()
# eval
if self.eval_dataloader is not None:
self.model.eval()
with torch.no_grad():
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
loss = outputs.loss
loss_sum += loss.item()
num_seen += batch["input_ids"].size(0)
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
# epoch_bar.update()
self.logger = logger
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-sft", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "sft")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
self.total_loss = 0
from .base import Strategy
from .colossalai import ColossalAIStrategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
from .naive import NaiveStrategy
__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin
from .sampler import DistributedSampler
ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
class Strategy(ABC):
"""
Base class for training strategies.
Base class for training strategies.
"""
def __init__(self) -> None:
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
super().__init__()
# NOTE: dist must be initialized before Booster
self.setup_distributed()
self.plugin = plugin_initializer()
self.booster = Booster(plugin=self.plugin)
self._post_init()
@abstractmethod
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
def _post_init(self) -> None:
pass
@abstractmethod
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
self.booster.backward(loss, optimizer)
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
pass
optimizer.step()
@abstractmethod
def setup_distributed(self) -> None:
pass
@abstractmethod
def setup_model(self, model: nn.Module) -> nn.Module:
pass
@abstractmethod
def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
pass
@abstractmethod
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
pass
def model_init_context(self):
return nullcontext()
def prepare(
self, *models_or_model_optim_pairs: ModelOrModelOptimPair
) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
"""Prepare models or model-optimizer-pairs based on each strategy.
def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
"""Prepare [model | (model, optimizer) | Dict] based on each strategy.
NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.
Example::
>>> # e.g., include lr_scheduler
>>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
>>> # when fine-tuning actor and critic
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
>>> # or when training reward model
......@@ -66,67 +66,72 @@ class Strategy(ABC):
>>> actor, critic = strategy.prepare(actor, critic)
Returns:
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
"""
def prepare_model(model: nn.Module):
if isinstance(model, Actor):
return Actor(self.setup_model(model.get_base_model()))
return self.setup_model(model)
rets = []
for arg in models_or_model_optim_pairs:
if isinstance(arg, tuple):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = prepare_model(model)
optimizer = self.setup_optimizer(optimizer, get_base_model(model))
for arg in boost_args:
if isinstance(arg, nn.Module):
model, *_ = self.booster.boost(arg)
rets.append(model)
elif isinstance(arg, tuple):
try:
model, optimizer = arg
except ValueError:
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(prepare_model(arg))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(
model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler,
)
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
raise RuntimeError(f"Type {type(arg)} is not supported")
if len(rets) == 1:
return rets[0]
return rets
return rets[0] if len(rets) == 1 else rets
@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
"""Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
For Actor, it will unwrap `actor.model`.
"""Get the unwrapped model from a wrapped model made by Strategy.prepare.
Args:
model (nn.Module): the model to unwrap
Returns:
nn.Module: the original model (usually a huggingface model)
nn.Module: the original model
"""
return get_base_model(model)
return model
@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
pass
def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
self.booster.save_model(model, path, shard=shard, **kwargs)
@abstractmethod
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
pass
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)
@abstractmethod
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
pass
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
@abstractmethod
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
pass
def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
self.booster.load_optimizer(optimizer, path)
def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, 1, 0)
@abstractmethod
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
pass
@abstractmethod
def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
import warnings
from typing import Optional, Union
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.models.base import get_base_model
from torch.optim import Optimizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
logger = get_dist_logger(__name__)
class LowLevelZeroStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI.
Args:
stage(int): The stage to use in ZeRO. Choose in (1, 2)
precision(str): The precision to use. Choose in ('fp32', 'fp16').
seed(int): The seed for the random number generator.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(
self,
stage: int = 2,
precision: str = "fp16",
seed: int = 42,
placement_policy: str = "cuda",
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
stage=stage,
precision=precision,
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == "cpu"),
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(
self.plugin, LowLevelZeroPlugin
), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, LowLevelZeroModel)
return model.module
def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
class ColossalAIStrategy(DDPStrategy):
class GeminiStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI.
Args:
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
......@@ -55,134 +125,76 @@ class ColossalAIStrategy(DDPStrategy):
"""
def __init__(
self,
stage: int = 3,
precision: str = 'fp16',
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
scatter_after_inference: bool = False, # only for stage 3
search_range_mb: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_mb: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0) -> None:
super().__init__(seed)
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
self.stage = stage
self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "auto",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
f"Shard init is not supported model.from_pretrained() yet. "
"Please load weights after strategy.prepare()"
)
if stage == 3 and precision == 'fp32':
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
precision = 'fp16'
self.precision = precision
self.shard_init = shard_init
self.gemini_config = dict(device=get_current_device(),
placement_policy=placement_policy,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_mb=search_range_mb,
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
scatter_after_inference=scatter_after_inference)
if stage == 3:
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
else:
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'))
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
chunk_init_device=get_current_device(),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
if self.stage == 3:
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
return super().model_init_context()
def setup_model(self, model: nn.Module) -> nn.Module:
model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
if self.stage != 3 and self.precision == 'fp16':
model = model.half().cuda()
return model
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.backward(loss)
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
return
base_model = get_base_model(model)
if self.stage == 3:
assert isinstance(base_model, ZeroDDP)
# for stage 3, state_dict() method should be called on every rank
state_dict = base_model.state_dict(only_rank_0=only_rank0)
else:
# only_rank0 is false or rank == 0
state_dict = base_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0:
raise RuntimeError(
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
torch.save(optimizer.state_dict(), path)
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
if self.stage == 3:
assert isinstance(base_model, ZeroDDP)
return base_model.module
return base_model
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if self.stage == 3:
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
super().save_pretrained(model, path, only_rank0, tokenizer)
assert isinstance(model, GeminiDDP)
return model.module
import os
import random
from typing import Optional
from collections import OrderedDict
from typing import Callable, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .naive import NaiveStrategy
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel
from .base import Strategy
from .sampler import DistributedSampler
class DDPStrategy(NaiveStrategy):
# TODO Move this to a util.py (Moving to ray.util introduces ringed import)
def get_grad_required_state_dict(model: nn.Module):
state_dict = OrderedDict()
for name, parameter in model.named_parameters():
if parameter.requires_grad:
state_dict[name] = parameter.detach()
return state_dict
class DDPStrategy(Strategy):
"""
Strategy for distributed training using torch.distributed.
Strategy for distributed training using torch.distributed.
"""
def __init__(self, seed: int = 42) -> None:
def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
self.seed = seed
super().__init__()
super().__init__(plugin_initializer)
def setup_distributed(self) -> None:
def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e
def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
self._try_init_dist(force=True)
self.set_seed(self.seed)
torch.cuda.set_device(local_rank)
def set_seed(self, seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_model(self, model: nn.Module) -> nn.Module:
device = torch.cuda.current_device()
return DDP(model, device_ids=[device])
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
# DDP only mode, replay buffers on each rank are different.
# sampler = DistributedSampler(replay_buffer,
# num_replicas=dist.get_world_size(),
# rank=dist.get_rank(),
# shuffle=True,
# seed=self.seed,
# drop_last=True)
return DataLoader(
replay_buffer,
batch_size=replay_buffer.sample_batch_size,
# sampler=sampler,
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(
data_buffer,
batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_model(model, path, only_rank0)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_optimizer(optimizer, path, only_rank0)
collate_fn=data_buffer.collate_fn,
)
def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: DDP = super().unwrap_model(model)
return base_model.module
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_pretrained(model, path, only_rank0, tokenizer)
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.unwrap()
def save_pretrained(
self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
if dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
assert isinstance(pretrained_model, PreTrainedModel)
# HACK: only use hf save_pretrained to save config
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, shard=shard)
def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if "requires_grad_only" in config and config["requires_grad_only"] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
if "shard_size" in config:
shard_size = config["shard_size"]
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
state_dict_shard[name] = param
accumulate_size += param.numel() * param.element_size()
if accumulate_size >= shard_size:
accumulate_size = 0
yield state_dict_shard
state_dict_shard = OrderedDict()
if accumulate_size > 0:
yield state_dict_shard
else:
yield state_dict
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.optim as optim
from coati.models.base import get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Strategy
class NaiveStrategy(Strategy):
"""
Strategy for single GPU. No parallelism is used.
"""
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
loss.backward()
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
def setup_distributed(self) -> None:
pass
def setup_model(self, model: nn.Module) -> nn.Module:
return model
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
return optimizer
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
return DataLoader(replay_buffer,
batch_size=replay_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
base_model = get_base_model(model)
state_dict = base_model.state_dict()
torch.save(state_dict, path)
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
base_model = get_base_model(model)
state_dict = torch.load(path, map_location=map_location)
base_model.load_state_dict(state_dict, strict=strict)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
torch.save(optimizer.state_dict(), path)
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
state_dict = torch.load(path, map_location=map_location)
optimizer.load_state_dict(state_dict)
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
......@@ -4,7 +4,6 @@ import numpy as np
class DistributedSampler:
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
......@@ -12,7 +11,7 @@ class DistributedSampler:
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
......@@ -20,10 +19,10 @@ class DistributedSampler:
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
......
......@@ -3,6 +3,38 @@ from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
class CycledDataLoader:
"""
Why do we need this class?
In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain.
However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...)
NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
"""
def __init__(
self,
dataloader: DataLoader,
) -> None:
self.dataloader = dataloader
self.count = 0
self.dataloader_iter = None
def next(self):
# defer initialization
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)
self.count += 1
try:
return next(self.dataloader_iter)
except StopIteration:
self.count = 0
self.dataloader_iter = iter(self.dataloader)
return next(self.dataloader_iter)
def is_rank_0() -> bool:
......@@ -10,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
......
from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize
__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import transformers
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
def prepare_llama_tokenizer_and_embedding(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
"""prepare llama tokenizer and embedding.
"""
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
tokenizer.add_special_tokens({
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
})
return tokenizer
def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
if tokenizer.pad_token is None:
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment