Unverified Commit df5e9c53 authored by YeAnbang, committed by GitHub

[ColossalChat] Update RLHF V2 (#5286)



* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
parent 36c4bb28
from typing import Optional, Union
import torch
import torch.nn.functional as F
def _compute_approx_kl(
log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
Args:
log_probs: Log probabilities of the new distribution.
log_probs_base: Log probabilities of the base distribution.
action_mask: Mask for actions.
"""
    log_ratio = log_probs_base - log_probs
    # k3 estimator from the blog post: (r - 1) - log(r) is a non-negative,
    # low-variance estimate of the KL divergence.
    approx_kl = (log_ratio.exp() - 1) - log_ratio
    if action_mask is not None:
        return masked_mean(approx_kl, action_mask, dim=1)
    return approx_kl.mean(dim=1)
def compute_reward(
    r: Union[torch.Tensor, float],
    kl_coef: float,
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Penalize the raw reward by the approximate KL divergence between the
    policy and the base (reference) model: reward = r - kl_coef * KL.
    """
    if kl_coef <= 0.0:
        return r
    kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
    reward = r - kl_coef * kl
    return reward
def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.
Args:
        logits (torch.Tensor): Output logits of Actor.forward.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
Returns:
torch.Tensor: Action log probs.
"""
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Mean over `dim`, counting only positions where `mask` is nonzero."""
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)  # epsilon guards against an all-zero mask
    return mean
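# --- Illustrative usage sketch (not part of the library; toy shapes and values) ---
if __name__ == "__main__":
    batch, seq_len, vocab, num_actions = 2, 6, 10, 3
    logits = torch.randn(batch, seq_len, vocab)            # Actor.forward output
    sequences = torch.randint(0, vocab, (batch, seq_len))
    # logits[:, t] predicts sequences[:, t + 1], hence the shift inside
    # calc_action_log_probs; the last `num_actions` positions are the response.
    action_log_probs = calc_action_log_probs(logits, sequences, num_actions)
    assert action_log_probs.shape == (batch, num_actions)
    action_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
    reward = compute_reward(
        r=torch.tensor([1.0, 0.5]),       # sequence-level reward model scores
        kl_coef=0.1,
        log_probs=action_log_probs,
        log_probs_base=action_log_probs,  # identical distributions -> zero KL penalty
        action_mask=action_mask,
    )
    assert torch.allclose(reward, torch.tensor([1.0, 0.5]))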
import os
import torch.distributed as dist
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer
from .base import Callback
class SaveCheckpoint(Callback):
"""
    Callback for saving checkpoints in coati.
    Only saving the actor and critic models is supported.
A typical architecture of the saved checkpoint would be:
- checkpoint
- episode_x
- actor.pt
- actor-optim-rank-0.pt
- actor-optim-rank-1.pt
- critic.pt
- critic-optim-rank-0.pt
- critic-optim-rank-1.pt
- ...
Args:
        path(str): the base path to save checkpoints under; checkpoints are saved at `path/checkpoint`
        interval(int): save a checkpoint every `interval` episodes
strategy(Strategy): the strategy used to train
actor(nn.Module): the actor model
critic(nn.Module): the critic model
actor_optim(Optimizer): the optimizer of actor
critic_optim(Optimizer): the optimizer of critic
"""
def __init__(
self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None,
) -> None:
super().__init__()
self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
# save model
if self.model_dict[model][0] is None:
                # saving only optimizer states is meaningless, so it is skipped
continue
model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
if self.model_dict[model][1] is None:
continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
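    # --- Illustrative sketch (not part of the library) ---
    # With interval=2 and 0-indexed episodes, on_episode_end saves on episodes
    # 1, 3, 5, ..., producing directories like:
    #   <path>/checkpoint/episode_1/actor.pt
    #   <path>/checkpoint/episode_1/actor-optim-rank-<r>.pt
    #   <path>/checkpoint/episode_1/critic.pt
    #   <path>/checkpoint/episode_1/critic-optim-rank-<r>.pt
    # A model entry whose module is None is skipped entirely; an entry with a
    # module but no optimizer saves only the model weights.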
from typing import Callable, Optional
import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0
class RewardModelTrainer(SLTrainer):
"""
    Trainer for training a reward model.
Args:
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim (Optimizer): the optimizer to use for training
lr_scheduler (_LRScheduler): the lr scheduler to use for training
loss_fn (callable): the loss function to use for training
        max_epochs (int, defaults to 1): the number of epochs to train
"""
def __init__(
self,
model,
strategy: Strategy,
optim: Optimizer,
lr_scheduler: _LRScheduler,
loss_fn: Callable,
max_epochs: int = 1,
) -> None:
super().__init__(strategy, max_epochs, model, optim)
self.loss_fn = loss_fn
self.scheduler = lr_scheduler
self.num_train_step = 0
def _eval(self, epoch):
if self.eval_dataloader is not None:
self.model.eval()
            dist_sum, num_correct, num_samples = 0, 0, 0
            with torch.no_grad():
                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
                    num_samples += chosen_ids.size(0)
                    num_correct += (chosen_reward > reject_reward).sum().item()
                    dist_sum += (chosen_reward - reject_reward).mean().item()
            self.dist = dist_sum / len(self.eval_dataloader)
            self.acc = num_correct / num_samples
if self.writer:
self.writer.add_scalar("eval/dist", self.dist, epoch)
self.writer.add_scalar("eval/acc", self.acc, epoch)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
if self.writer:
self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
self.writer.add_scalar(
"train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
)
self.num_train_step += 1
if self.num_train_step % 100 == 0:
self.scheduler.step()
step_bar.update()
step_bar.close()
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: DataLoader,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
from .base import Strategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin
from .sampler import DistributedSampler
_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
class Strategy(ABC):
"""
Base class for training strategies.
"""
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
super().__init__()
# NOTE: dist must be initialized before Booster
self.setup_distributed()
self.plugin = plugin_initializer()
self.booster = Booster(plugin=self.plugin)
self._post_init()
@abstractmethod
def _post_init(self) -> None:
pass
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
self.booster.backward(loss, optimizer)
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
optimizer.step()
@abstractmethod
def setup_distributed(self) -> None:
pass
@abstractmethod
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
pass
def model_init_context(self):
return nullcontext()
def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
"""Prepare [model | (model, optimizer) | Dict] based on each strategy.
NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.
Example::
>>> # e.g., include lr_scheduler
>>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
>>> # when fine-tuning actor and critic
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
>>> # or when training reward model
>>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
>>> # or just inference
>>> actor, critic = strategy.prepare(actor, critic)
Returns:
Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
"""
rets = []
for arg in boost_args:
if isinstance(arg, nn.Module):
model, *_ = self.booster.boost(arg)
rets.append(model)
elif isinstance(arg, tuple):
try:
model, optimizer = arg
except ValueError:
                    raise RuntimeError(f"Expected a (model, optimizer) pair, got a tuple of size {len(arg)}")
model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(
model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler,
)
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f"Type {type(arg)} is not supported")
return rets[0] if len(rets) == 1 else rets
@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
"""Get the unwrapped model from a wrapped model made by Strategy.prepare.
Args:
model (nn.Module): the model to unwrap
Returns:
nn.Module: the original model
"""
return model
def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
self.booster.save_model(model, path, shard=shard, **kwargs)
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
self.booster.load_optimizer(optimizer, path)
def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray; not tested after adapting to the Boost API.
return DistributedSampler(dataset, 1, 0)
@abstractmethod
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
pass
@abstractmethod
def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
import warnings
from typing import Optional
import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
class LowLevelZeroStrategy(DDPStrategy):
"""
    The strategy for training with ColossalAI ZeRO (stage 1 or 2).
    Args:
        stage(int): The stage to use in ZeRO. Choose in (1, 2).
        precision(str): The precision to use. Choose in ('fp32', 'fp16').
        seed(int): The seed for the random number generator.
        placement_policy(str): The placement policy. Choose in ('cpu', 'cuda').
            If it is 'cpu', parameters, gradients and optimizer states will be offloaded to CPU.
            If it is 'cuda', they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(
self,
stage: int = 2,
precision: str = "fp16",
seed: int = 42,
placement_policy: str = "cuda",
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
stage=stage,
precision=precision,
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == "cpu"),
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(
self.plugin, LowLevelZeroPlugin
), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, LowLevelZeroModel)
return model.module
def get_model_state_dict_shard(self, model: nn.Module, **config):
assert isinstance(model, LowLevelZeroModel)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
class GeminiStrategy(DDPStrategy):
"""
    The strategy for training with ColossalAI Gemini (ZeRO-3).
Args:
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
        placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda', 'auto').
            If it is 'cpu', parameters, gradients and optimizer states will be offloaded to CPU.
            If it is 'cuda', they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
growth_interval(int): The growth interval for the optimizer.
hysteresis(int): The hysteresis for the optimizer.
min_scale(float): The minimum scale for the optimizer.
max_scale(float): The maximum scale for the optimizer.
max_norm(float): The maximum norm for the optimizer.
norm_type(float): The norm type for the optimizer.
"""
def __init__(
self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "auto",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
            warnings.warn(
                "Shard init is not supported with `model.from_pretrained()` yet. "
                "Please load weights after strategy.prepare()."
            )
        self.shard_init = shard_init
        warnings.warn("Stage 3 only supports fp16. Precision is set to fp16.")
        # ColossalAI changed the API for getting the current device in version 0.3.4 and newer
        try:
            from colossalai.accelerator import get_accelerator

            chunk_init_device = get_accelerator().get_current_device()
        except ImportError:
            from colossalai.utils import get_current_device

            chunk_init_device = get_current_device()
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
chunk_init_device=chunk_init_device,
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
return super().model_init_context()
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiDDP)
return model.module
import os
import random
from collections import OrderedDict
from typing import Callable, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel
from .base import Strategy
from .sampler import DistributedSampler
# TODO: move this to a utils.py (moving it to ray.util would introduce a circular import)
def get_grad_required_state_dict(model: nn.Module):
state_dict = OrderedDict()
for name, parameter in model.named_parameters():
if parameter.requires_grad:
state_dict[name] = parameter.detach()
return state_dict
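# --- Illustrative usage sketch (not part of the library; toy model) ---
if __name__ == "__main__":
    _demo = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    _demo[0].weight.requires_grad_(False)  # e.g. a frozen base weight under LoRA
    _partial = get_grad_required_state_dict(_demo)
    # Only trainable parameters survive the filter.
    assert sorted(_partial.keys()) == ["0.bias", "1.bias", "1.weight"]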
class DDPStrategy(Strategy):
"""
Strategy for distributed training using torch.distributed.
"""
def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
self.seed = seed
super().__init__(plugin_initializer)
def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
except Exception as e:
if force:
raise e
def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
self._try_init_dist(force=True)
self.set_seed(self.seed)
def set_seed(self, seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(
data_buffer,
batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=data_buffer.collate_fn,
)
def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray; not tested after adapting to the Boost API.
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.unwrap()
def save_pretrained(
self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
if dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
assert isinstance(pretrained_model, PreTrainedModel)
# HACK: only use hf save_pretrained to save config
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, shard=shard)
def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if "requires_grad_only" in config and config["requires_grad_only"] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
if "shard_size" in config:
shard_size = config["shard_size"]
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():
state_dict_shard[name] = param
accumulate_size += param.numel() * param.element_size()
if accumulate_size >= shard_size:
accumulate_size = 0
yield state_dict_shard
state_dict_shard = OrderedDict()
if accumulate_size > 0:
yield state_dict_shard
else:
yield state_dict
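# --- Standalone sketch of the sharding loop above (not part of the library) ---
# Tensors are accumulated until the byte budget `shard_size` is reached, then
# the shard is yielded and a new one is started; a non-empty remainder is
# flushed at the end.
def _demo_shard_state_dict(state_dict, shard_size):
    shard, accumulated = OrderedDict(), 0
    for name, param in state_dict.items():
        shard[name] = param
        accumulated += param.numel() * param.element_size()
        if accumulated >= shard_size:
            yield shard
            shard, accumulated = OrderedDict(), 0
    if shard:
        yield shard

if __name__ == "__main__":
    _sd = OrderedDict((f"w{i}", torch.zeros(256)) for i in range(4))  # 1 KiB per fp32 tensor
    assert len(list(_demo_shard_state_dict(_sd, shard_size=2048))) == 2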
import math
import numpy as np
class DistributedSampler:
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
def sample(self, batch_size: int) -> list:
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
return [self.dataset[idx] for idx in sampled_indices]
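# --- Illustrative usage sketch (not part of the library) ---
if __name__ == "__main__":
    _dataset = list(range(10))
    _sampler = DistributedSampler(_dataset, num_replicas=2, rank=0)
    assert _sampler.indices == [0, 2, 4, 6, 8]  # rank 0 takes every 2nd index
    assert len(_sampler.sample(3)) == 3  # random batch drawn without replacement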
# Examples
## Table of Contents
- [Examples](#examples)
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
- [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
- [Features and tricks in RM training](#features-and-tricks-in-rm-training)
- [Experiment result](#experiment-result)
- [Arg List](#arg-list-1)
- [Stage3 - Training model using prompts with RL](#stage3---training-model-using-prompts-with-rl)
- [Arg List](#arg-list-2)
- [Inference example - After Stage3](#inference-example---after-stage3)
- [Attention](#attention)
- [data](#data)
- [Support Model](#support-model)
- [GPT](#gpt)
- [BLOOM](#bloom)
- [OPT](#opt)
- [LLaMA](#llama)
- [Add your own models](#add-your-own-models)
- [Actor model](#actor-model)
- [Reward model](#reward-model)
- [Critic model](#critic-model)
---
## Install requirements
```shell
pip install -r requirements.txt
```
## Supervised datasets collection
We collected 104K bilingual (Chinese and English) instruction samples; you can find the datasets in the
[InstructionWild](https://github.com/XueFuzhao/InstructionWild) repo and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md).
Here is how we collected the data:
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/data-collect.png" width=500/>
</p>
### Conversation dataset generation
In order to further improve the model's ability to handle multi-turn conversations, we need to include samples with multi-turn conversations in the dataset. However, the samples in the InstructWild and Alpaca datasets currently consist of only single-turn conversations, and their dataset organization is not suitable for storing multi-turn conversations. In addition to converting the aforementioned datasets, we also need to include multi-turn conversation datasets like ShareGPT, transformed into the training format supported by ColossalChat.
A sample of the conversation dataset should have the following fields:
- `type` (str, optional): The type of the data sample.
- `language` (str, optional): The language of the data sample.
- `dataset` (str, optional): The dataset the data sample originates from.
- `conversations` (list, compulsory): Conversation content of the data sample.
- `id` (int, optional): The ID of the data sample.
A simple example:
```json
{
"type": "instruction",
"language": "English",
"dataset": "Alpaca",
"conversations": [
{
"from": "human",
"value": "Give three tips for staying healthy."
},
{
"from": "gpt",
"value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
}
],
"id": 1
}
```
> **NOTE:** Only the key `conversations` is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.
You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat, using the following cmd:
```bash
python generate_conversation_dataset.py \
    --dataset "All" \
    --save_path "/path/to/dataset"
```
## Stage1 - Supervised instructs tuning
Stage 1 is supervised instruction fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg)
You can run `examples/train_sft.sh` to start supervised instruction fine-tuning.
You can also use the following cmd to start fine-tuning with your own settings.
```bash
torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 4 \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 512 \
--max_epochs 1 \
--grad_checkpoint
```
**Note**: the supervised dataset uses the following format:
```json
[
{
"instruction": "Provide a list of the top 10 most popular mobile games in Asia",
"input": "",
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id": 0
},
...
]
```
### Arg List
- `--strategy`: the strategy used for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
- `--pretrain`: pretrain model, type=str, default=None
- `--max_datasets_size`: the max size of dataset, type=int, default=None
- `--save_path`: path to save the model, type=str, default='output'
- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False
- `--max_epochs`: max epochs for training, type=int, default=3
- `--batch_size`: batch size while training, type=int, default=4
- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0
- `--grad_checkpoint`: enable gradient checkpointing, type=bool, default=False
## Stage2 - Training reward model
In stage 2 we train a reward model that assigns scores to model outputs; it is supervised by human rankings of different outputs for the same prompt.
[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo)
You can run `examples/train_rm.sh` to start reward model training.
You can also use the following cmd to start training a reward model.
```bash
torchrun --standalone --nproc_per_node=4 train_reward_model.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
    --loss_fn 'log_exp' \
    --save_path 'rmstatic.pt'
```
### Features and tricks in RM training
- We support the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
- We support two kinds of loss functions, `log_sig` (used by OpenAI) and `log_exp` (used by Anthropic); see the sketch after this list.
- In addition to the loss, we track `valid_acc` and `pair_dist` to monitor progress during training.
- We add a special token to the end of the sequence to get better results.
- We use a cosine-annealing learning-rate scheduler for RM training.
- We implement `value_head` as a single linear layer and initialize its weight from the N(0, 1/(d_model + 1)) distribution.
- We train a BLOOM-560m reward model for 1 epoch and find that its test accuracy matches the performance reported in [Anthropic's paper](https://arxiv.org/abs/2204.05862).
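For reference, here is a minimal sketch of the two pairwise losses (the actual implementations live in the coati package; note that the two forms are mathematically identical, since -log sigmoid(x) = log(1 + e^(-x))):
```python
import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # -log(sigmoid(r_chosen - r_reject)), the form used by OpenAI
    return -F.logsigmoid(chosen_reward - reject_reward).mean()

def log_exp_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # log(1 + exp(r_reject - r_chosen)), the form used by Anthropic
    return torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
```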
### Experiment result
Model performance reported in [Anthropic's paper](https://arxiv.org/abs/2204.05862):
<div align=middle> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225263321-8d64c3a8-6877-4cc8-9b61-0e1c52d3d94f.png">
<div align=left>Our training & test results for BLOOM-560m after 1 epoch:
<div align=middle> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225262950-a7f0a686-25de-44ec-98f2-11b83ea86674.png">
<div align=left>We also trained a reward model based on LLaMA-7B, which reaches an accuracy of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM.
### Arg List
- `--strategy`: the strategy used for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
- `--pretrain`: pretrain model, type=str, default=None
- `--model_path`: the path of the reward model (to resume training from), type=str, default=None
- `--save_path`: path to save the model, type=str, default='output'
- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False
- `--max_epochs`: max epochs for training, type=int, default=3
- `--dataset`: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static']
- `--subset`: subset of the dataset, type=str, default=None
- `--batch_size`: batch size while training, type=int, default=4
- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0
- `--loss_func`: which kind of loss function, choices=['log_sig', 'log_exp']
- `--max_len`: max sentence length for generation, type=int, default=512
## Stage3 - Training model using prompts with RL
Stage 3 uses a reinforcement learning algorithm; it is the most complex part of the training process, as shown below:
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
</p>
You can run `examples/train_prompts.sh` to start PPO training.
You can also use the following cmd to start PPO training.
[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g)
```bash
torchrun --standalone --nproc_per_node=4 train_prompts.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
--prompt_dataset /path/to/your/prompt_dataset \
--pretrain_dataset /path/to/your/pretrain_dataset \
--rm_pretrain /your/pretrain/rm/definition \
--rm_path /your/rm/model/path
```
Prompt dataset: the instruction dataset mentioned in the figure above, which contains the instructions. For example, you can use this [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py), which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data), to generate the prompt dataset.
Pretrain dataset: the pretrain dataset including the instructions and corresponding responses. For example, you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from stage 1 supervised instruction tuning.
**Note**: the required datasets use the following formats:
- `pretrain dataset`
```json
[
{
"instruction": "Provide a list of the top 10 most popular mobile games in Asia",
"input": "",
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id": 0
},
...
]
```
- `prompt dataset`
```json
[
{
"instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
"id": 0
},
{
"instruction": "Write a descriptive paragraph about a memorable vacation you went on",
"id": 1
},
...
]
```
### Arg List
- `--strategy`: the strategy used for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
- `--model`: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
- `--pretrain`: pretrain model, type=str, default=None
- `--rm_model`: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
- `--rm_pretrain`: pretrain model for reward model, type=str, default=None
- `--rm_path`: the path of rm model, type=str, default=None
- `--save_path`: path to save the model, type=str, default='output'
- `--prompt_dataset`: path of the prompt dataset, type=str, default=None
- `--pretrain_dataset`: path of the ptx dataset, type=str, default=None
- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False
- `--num_episodes`: num of episodes for training, type=int, default=10
- `--num_update_steps`: number of steps to update policy per episode, type=int
- `--num_collect_steps`: number of steps to collect experience per episode, type=int
- `--train_batch_size`: batch size while training, type=int, default=8
- `--ptx_batch_size`: batch size to compute ptx loss, type=int, default=1
- `--experience_batch_size`: batch size to make experience, type=int, default=8
- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0
- `--kl_coef`: KL penalty coefficient used for computing the reward, type=float, default=0.1
- `--ptx_coef`: coefficient of the ptx loss in the policy loss (see the sketch below), type=float, default=0.9
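To clarify how these coefficients enter the objective, here is a hedged sketch of a standard PPO clipped policy loss with an auxiliary pretraining (ptx) term, in the spirit of InstructGPT; the exact weighting used by the trainer may differ:
```python
import torch

def policy_loss_with_ptx(
    log_probs: torch.Tensor,      # new policy log-probs of the actions
    old_log_probs: torch.Tensor,  # log-probs recorded when collecting experience
    advantages: torch.Tensor,
    ptx_log_probs: torch.Tensor,  # log-probs of pretraining-data labels
    ptx_coef: float = 0.9,
    clip_eps: float = 0.2,
) -> torch.Tensor:
    # Standard PPO clipped surrogate objective.
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    actor_loss = -torch.min(surr1, surr2).mean()
    # Language-modeling loss on pretraining data keeps the policy close to
    # its pretrained distribution; kl_coef, by contrast, acts earlier, when
    # the reward itself is penalized (see compute_reward above).
    ptx_loss = -ptx_log_probs.mean()
    return actor_loss + ptx_coef * ptx_loss
```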
## Inference example - After Stage3
We support different inference options, including int8 and int4 quantization.
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
## Attention
The examples are demos of the whole training process. You need to tune the hyper-parameters to achieve good performance.
#### data
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons)
- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses)
## Support Model
### GPT
- [x] GPT2-S (s)
- [x] GPT2-M (m)
- [x] GPT2-L (l)
- [x] GPT2-XL (xl)
- [x] GPT2-4B (4b)
- [ ] GPT2-6B (6b)
### BLOOM
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloom-7b1)
- [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom)
### OPT
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
- [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
- [x] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
- [x] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [x] LLaMA-7B
- [x] LLaMA-13B
- [ ] LLaMA-33B
- [ ] LLaMA-65B
## Add your own models
If you want to support your own model in Coati, please refer to the pull request that added RoBERTa support as an example, [[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3](https://github.com/hpcaitech/ColossalAI/pull/3223), and submit a PR to us.
You should complete the implementation of four model classes: the reward model, the critic model, the LM model and the actor model.
Here is some example code for a new model named `Coati`.
If the model is supported in huggingface [transformers](https://github.com/huggingface/transformers), you can load it with `from_pretrained`; otherwise, you can build the model yourself.
### Actor model
```python
from ..base import Actor
from transformers.models.coati import CoatiModel
class CoatiActor(Actor):
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = CoatiModel.from_pretrained(pretrained)
else:
            model = build_model()  # load your own model if it is not supported in transformers
super().__init__(model, lora_rank, lora_train_bias)
```
### Reward model
```python
from ..base import RewardModel
from transformers.models.coati import CoatiModel
class CoatiRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = CoatiModel.from_pretrained(pretrained)
else:
            model = build_model()  # load your own model if it is not supported in transformers
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
```
### Critic model
```python
from ..base import Critic
from transformers.models.coati import CoatiModel
class CoatiCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = CoatiModel.from_pretrained(pretrained)
else:
            model = build_model()  # load your own model if it is not supported in transformers
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
```
import argparse
import dataclasses
import os
from typing import List
import tqdm
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer
@dataclasses.dataclass
class HFRepoFiles:
repo_id: str
files: List[str]
def download(self, dir_path: str):
for file in self.files:
            hf_hub_download(self.repo_id, file, local_dir=dir_path)
def download_all(self):
snapshot_download(self.repo_id)
def test_init(model: str, dir_path: str):
    # Smoke test: instantiate the actor, critic, reward model and tokenizer from the downloaded files.
if model == "gpt2":
config = GPT2Config.from_pretrained(dir_path)
actor = GPTActor(config=config)
critic = GPTCritic(config=config)
reward_model = GPTRM(config=config)
GPT2Tokenizer.from_pretrained(dir_path)
elif model == "bloom":
config = BloomConfig.from_pretrained(dir_path)
actor = BLOOMActor(config=config)
critic = BLOOMCritic(config=config)
reward_model = BLOOMRM(config=config)
BloomTokenizerFast.from_pretrained(dir_path)
elif model == "opt":
config = AutoConfig.from_pretrained(dir_path)
actor = OPTActor(config=config)
critic = OPTCritic(config=config)
reward_model = OPTRM(config=config)
AutoTokenizer.from_pretrained(dir_path)
else:
raise NotImplementedError(f"Model {model} not implemented")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-dir", type=str, default="test_models")
parser.add_argument("--config-only", default=False, action="store_true")
args = parser.parse_args()
if os.path.exists(args.model_dir):
print(f"[INFO]: {args.model_dir} already exists")
exit(0)
repo_list = {
"gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
"bloom": HFRepoFiles(
repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
),
"opt": HFRepoFiles(
repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
),
}
os.mkdir(args.model_dir)
for model_name in tqdm.tqdm(repo_list):
dir_path = os.path.join(args.model_dir, model_name)
if args.config_only:
os.mkdir(dir_path)
repo_list[model_name].download(dir_path)
else:
repo_list[model_name].download_all()
test_init(model_name, dir_path)
import argparse
import json
from datasets import load_dataset
def generate_alpaca():
    # Datasets with the same fields as Alpaca ("instruction", "input", "output") can be converted into one-round conversations.
conversation_dataset = []
dataset = load_dataset("tatsu-lab/alpaca", split="train")
instructions = dataset["instruction"]
inputs = dataset["input"]
outputs = dataset["output"]
assert len(instructions) == len(inputs) == len(outputs)
for idx in range(len(instructions)):
human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
human = {"from": "human", "value": human_utterance}
gpt_utterance = outputs[idx]
gpt = {"from": "gpt", "value": gpt_utterance}
conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
conversation_dataset.append(conversation)
return conversation_dataset
def generate_sharegpt():
# ShareGPT data requires less processing.
conversation_dataset = []
dataset = load_dataset(
"anon8231489123/ShareGPT_Vicuna_unfiltered",
data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
split="train",
)
conversations = dataset["conversations"]
for idx in range(len(conversations)):
for conv in conversations[idx]:
            # We don't need the markdown and text fields.
del conv["markdown"]
del conv["text"]
conversation = dict(
type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
)
conversation_dataset.append(conversation)
return conversation_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="All",
choices=["Alpaca", "ShareGPT", "All"],
help="which dataset to convert, All will combine Alpaca and ShareGPT",
)
parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
args = parser.parse_args()
conversation_dataset = []
if args.dataset == "Alpaca":
conversation_dataset.extend(generate_alpaca())
elif args.dataset == "ShareGPT":
conversation_dataset.extend(generate_sharegpt())
else:
conversation_dataset.extend(generate_alpaca())
conversation_dataset.extend(generate_sharegpt())
for idx, sample in enumerate(conversation_dataset):
sample["id"] = idx + 1
with open(args.save_path, mode="w") as f:
json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
import argparse
import json
import random
random.seed(42)
def sample(args):
with open(args.dataset_path, mode="r") as f:
dataset_list = json.load(f)
sampled_dataset = [
{"instruction": sample["instruction"], "id": idx}
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
]
with open(args.save_path, mode="w") as f:
json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)
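    # Example usage (hypothetical paths; the input file is a stage-1 style JSON list):
    #   python generate_prompt_dataset.py --dataset_path instinwild_en.json \
    #       --save_path prompt.json --sample_size 16384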