Unverified Commit df5e9c53 authored by YeAnbang's avatar YeAnbang Committed by GitHub
Browse files

[ColossalChat] Update RLHF V2 (#5286)



* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: default avatarTong Li <tong.li352711588@gmail.com>
parent 36c4bb28
from .conversation import Conversation, setup_conversation_template
from .loader import (
DataCollatorForPreferenceDataset,
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_distributed_dataloader,
)
from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf
__all__ = [
"tokenize_prompt_dataset",
"DataCollatorForPromptDataset",
"is_rank_0",
"DataCollatorForPreferenceDataset",
"DataCollatorForSupervisedDataset",
"StatefulDistributedSampler",
"load_tokenized_dataset",
"setup_distributed_dataloader",
"supervised_tokenize_pretrain",
"supervised_tokenize_sft",
"tokenize_rlhf",
"setup_conversation_template",
"Conversation",
]
import dataclasses
import json
import os
from typing import Any, Dict, List
import torch.distributed as dist
from transformers import AutoTokenizer, PreTrainedTokenizer
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
@dataclasses.dataclass
class Conversation:
tokenizer: PreTrainedTokenizer
system_message: str
chat_template: str
stop_ids: List[int]
@classmethod
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
"""
Setup the conversation template from config
"""
tokenizer.chat_template = config["chat_template"]
conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
conv.clear()
return conv
def clear(self):
self.messages = []
@classmethod
def get_conversation_template_keys(cls):
return ["system_message", "chat_template"]
def __str__(self):
return json.dumps(
{k: self.__dict__[k] for k in self.__dict__ if k not in ["tokenizer", "messages"]},
ensure_ascii=False,
indent=4,
)
def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
"""
Retrieves the prompt for the conversation.
Args:
length (int, optional): The number of messages to include in the prompt. Defaults to None.
get_seps_info (bool, optional): Whether to include separator information in the output. Defaults to False.
add_generation_prompt (bool, optional): Whether to add the assistant line start token in generation (for generation only). Defaults to False.
Returns:
str or tuple: The prompt string if get_seps_info is False, otherwise a tuple containing the prompt string and separator information.
"""
if length is None:
length = len(self.messages)
assert length <= len(self.messages)
if self.system_message is not None:
messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
else:
messages = self.messages[:length]
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt
)
return prompt
def save_prompt(self):
return self.get_prompt()
def append_message(self, role: str, message: str):
"""
Append a message to the conversation.
Args:
role (str): The role of the message sender. Must be either 'user' or 'assistant'.
message (str): The content of the message.
Raises:
AssertionError: If the role is not 'user' or 'assistant'.
"""
assert role in ["user", "assistant"]
self.messages.append({"role": role, "content": message})
def copy(self):
return Conversation(tokenizer=self.tokenizer, chat_template=self.chat_template)
def setup_conversation_template(
tokenizer: PreTrainedTokenizer, chat_template_config: Dict = None, save_path: str = None
) -> Conversation:
"""
Setup the conversation template, if chat_template is given, will replace the default chat_template of the tokenizer
with it. Otherwise, the default chat_template will be used. If the tokenizer doesn't have a default chat_template,
raise error to remind the user to set it manually.
Args:
tokenizer: The tokenizer to use
chat_template_config:
{
"system_message": str The system message to use
"chat_template": str The chat_template to use, if can be a chat_template, a huggingface model path or a local model.
if a huggeface model path or a local model, the chat_template will be loaded from the model's tokenizer's default chat template.
"stop_ids": List[int], the token ids used to terminate generation. You need to provide this for ppo training and generation.
}
"""
if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
# Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually
if "system_message" not in chat_template_config:
logger.warning("No system message is provided, will not use system message.")
if "chat_template" not in chat_template_config:
logger.warning("No chat_template is provided, will try to load it from the tokenizer.")
if tokenizer.chat_template != None:
chat_template_config["chat_template"] = tokenizer.chat_template
else:
raise ValueError(
f"Load a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set it manually."
)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(chat_template_config["chat_template"])
if tokenizer.chat_template != None:
chat_template_config["chat_template"] = tokenizer.chat_template
else:
raise ValueError(
f"Load a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set it manually."
)
logger.warning(
f"chat_template is provided as a local model path or huggingface model path, loaded chat_template from \"{chat_template_config['chat_template']}\"."
)
except OSError:
pass
except ValueError as e:
raise ValueError(e)
if not dist.is_initialized() or dist.get_rank() == 0:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.")
json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
return Conversation.from_config(tokenizer, chat_template_config)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataloader for sft, dpo, ppo
"""
import math
import os
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
def load_tokenized_dataset(
dataset_paths: Union[PathType, List[PathType]], mode: str = "train", **kwargs
) -> Optional[DatasetType]:
"""
Load pre-tokenized dataset.
Each instance of dataset is a dictionary with
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
"""
mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
if isinstance(dataset_paths, (str, os.PathLike)):
dataset_paths = [dataset_paths]
datasets = [] # `List[datasets.dataset_dict.Dataset]`
for ds_path in dataset_paths:
ds_path = os.path.abspath(ds_path)
assert os.path.exists(ds_path), f"Not existed file path {ds_path}"
ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
if isinstance(ds_dict, HFDataset):
datasets.append(ds_dict)
else:
if mode_map[mode] in ds_dict:
datasets.append(ds_dict[mode_map[mode]])
if len(datasets) == 0:
return None
if len(datasets) == 1:
return datasets.pop()
return ConcatDataset(datasets=datasets)
@dataclass
class DataCollatorForSupervisedDataset(object):
"""
Collate instances for supervised dataset.
Each instance is a tokenized dictionary with fields
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary.
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
# `List[torch.Tensor]`
batch_input_ids = [
torch.LongTensor(instance["input_ids"][: self.max_length])
if len(instance["input_ids"]) > self.max_length
else torch.LongTensor(instance["input_ids"])
for instance in instances
]
batch_labels = [
torch.LongTensor(instance["labels"][: self.max_length])
if len(instance["labels"]) > self.max_length
else torch.LongTensor(instance["labels"])
for instance in instances
]
if self.tokenizer.padding_side == "right":
input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=batch_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
labels = torch.nn.utils.rnn.pad_sequence(
sequences=batch_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
# pad to max
to_pad = self.max_length - input_ids.size(1)
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
elif self.tokenizer.padding_side == "left":
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
) # (bsz, max_len)
input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
reversed_labels = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_labels,
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
else:
raise RuntimeError(
f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
f"but now `{self.tokenizer.padding_side}`"
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
@dataclass
class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary.
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
"""
instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
ret = super().__call__(instances=instances)
input_ids = F.pad(
ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
)
attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@dataclass
class DataCollatorForPreferenceDataset(object):
"""
Collate instances for supervised dataset.
Each instance is a tokenized dictionary with fields
`input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
"""
tokenizer: PreTrainedTokenizer
max_length: int = 4096
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
Args:
instances (`Sequence[Dict[str, List[int]]]`):
Mini-batch samples, each sample is stored in an individual dictionary.
Returns:
(`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
`labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
"""
assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
f"but now `{self.tokenizer.pad_token_id}`"
)
(
chosen_input_ids,
chosen_loss_mask, # [batch_size * seq_len]
reject_input_ids,
reject_loss_mask,
) = (
chuncate_sequence([ins["chosen_input_ids"] for ins in instances], self.max_length, torch.int64),
chuncate_sequence([ins["chosen_loss_mask"] for ins in instances], self.max_length, torch.bool),
chuncate_sequence([ins["rejected_input_ids"] for ins in instances], self.max_length, torch.int64),
chuncate_sequence([ins["rejected_loss_mask"] for ins in instances], self.max_length, torch.bool),
)
padding_side = self.tokenizer.padding_side
chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids]
reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids]
(
chosen_input_ids,
chosen_attention_mask,
chosen_loss_mask,
reject_input_ids,
reject_attention_mask,
reject_loss_mask,
) = (
pad_to_max_len(chosen_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side),
pad_to_max_len(chosen_attention_mask, self.max_length, False, padding_side=padding_side),
pad_to_max_len(chosen_loss_mask, self.max_length, False, padding_side=padding_side),
pad_to_max_len(reject_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side),
pad_to_max_len(reject_attention_mask, self.max_length, False, padding_side=padding_side),
pad_to_max_len(reject_loss_mask, self.max_length, False, padding_side=padding_side),
)
return dict(
chosen_input_ids=chosen_input_ids,
chosen_attention_mask=chosen_attention_mask,
chosen_loss_mask=chosen_loss_mask,
reject_input_ids=reject_input_ids,
reject_attention_mask=reject_attention_mask,
reject_loss_mask=reject_loss_mask,
)
class StatefulDistributedSampler(DistributedSampler):
"""
Stateful distributed sampler for multi-stage training.
"""
def __init__(
self,
dataset: DatasetType,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
use_tp: Optional[bool] = False,
) -> None:
if not use_tp:
super().__init__(
dataset=dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
)
else:
# adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62
# TODO: support tp_group>1. will fix it later
num_replicas = 1
if rank is None:
rank = dist.get_rank()
if rank < 0:
raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
self.start_index = 0
self.use_tp = use_tp
def __iter__(self) -> Iterator:
if self.use_tp:
# TODO Add support for tp_group not equal to 1
pass
# adpated from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[
: self.total_size : self.num_replicas
] # num_replicas=tp_group=1, we only support tp_group==1 for now
assert len(indices) == self.num_samples
return iter(indices)
else:
iterator = super().__iter__()
indices = list(iterator)
indices = indices[self.start_index :]
return iter(indices)
def __len__(self) -> int:
return self.num_samples - self.start_index
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def setup_distributed_dataloader(
dataset: DatasetType,
batch_size: int = 1,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
process_group: Optional[ProcessGroup] = None,
use_tp: Optional[bool] = False,
**kwargs,
) -> DataLoader:
"""
Setup dataloader for distributed training.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset=dataset,
num_replicas=process_group.size() if not use_tp else 1,
rank=process_group.rank(),
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
use_tp=use_tp,
)
# Deterministic dataloader
def seed_worker(worker_id: int) -> None:
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
worker_init_fn=seed_worker,
**_kwargs,
)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tokenization utils for constructing dataset for ppo, dpo, sft, rm
"""
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Union
from coati.dataset.conversation import Conversation
from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
from datasets import dataset_dict
from torch.utils.data import ConcatDataset, Dataset
from transformers import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
IGNORE_INDEX = -100
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
def supervised_tokenize_sft(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following
and calculate corresponding labels for sft training:
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line end]Something here"
^
end_of_system_line_position
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
if len(template.messages) % 2 != 0:
template.messages = template.messages[0:-1]
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
turns = [i for i in range(1, len(messages) // 2 + 1)]
lo, hi = 0, len(turns)
while lo < hi:
mid = (lo + hi) // 2
if max_length - 1 < len(
tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
):
hi = mid
else:
lo = mid + 1
target_turn_index = lo
# The tokenized length for first turn already exceeds `max_length - 1`.
if target_turn_index - 1 < 0:
warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
return dict(
input_ids=None,
labels=None,
inputs_decode=None,
labels_decode=None,
seq_length=None,
seq_category=None,
)
target_turn = turns[target_turn_index - 1]
prompt = template.get_prompt(2 * target_turn)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
labels = [ignore_index] * len(tokenized)
label_decode = []
for start, end in zip(starts, ends):
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [ignore_index]
labels[start : end + 1] = tokenized[start : end + 1]
label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
labels = [ignore_index] + labels
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
labels = labels + [tokenizer.eos_token_id]
else:
labels[-1] = tokenizer.eos_token_id
# For some model without bos/eos may raise the following errors
try:
inputs_decode = tokenizer.decode(tokenized)
except TypeError as e:
raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
# Check if all labels are ignored, this may happen when the tokenized length is too long
if labels.count(ignore_index) == len(labels):
return dict(
input_ids=None,
labels=None,
inputs_decode=None,
labels_decode=None,
seq_length=None,
seq_category=None,
)
return dict(
input_ids=tokenized,
labels=labels,
inputs_decode=inputs_decode,
labels_decode=label_decode,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def tokenize_prompt_dataset(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following for ppo training:
"Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
Args:
data_point: the data point of the following format
{"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
tokenizer: the tokenizer whose
conversation_template: the conversation template to apply
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
for mess in messages:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
template.append_message(from_str, mess["content"])
# `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time.
target_turn = len(template.messages)
if target_turn % 2 != 1:
# exclude the answer if provided. keep only the prompt
target_turn = target_turn - 1
# Prepare data
prompt = template.get_prompt(target_turn, add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
# Skip overlength data
if max_length - 1 < len(tokenized):
return dict(
input_ids=None,
inputs_decode=None,
seq_length=None,
seq_category=None,
)
# `inputs_decode` can be used to check whether the tokenization method is true.
return dict(
input_ids=tokenized,
inputs_decode=tokenizer.decode(tokenized),
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def apply_rlhf_data_format(
template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
target_turn = int(len(template.messages) / 2)
prompt = template.get_prompt(target_turn * 2)
chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
loss_mask = [0] * len(tokenized)
mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
if mask_token is None:
mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen
label_decode = []
for start, end in zip(starts[-1:], ends[-1:]):
# only the last round (chosen/rejected) counts
if end == len(tokenized):
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1])
label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
if tokenizer.bos_token_id is not None:
if tokenized[0] != tokenizer.bos_token_id:
tokenized = [tokenizer.bos_token_id] + tokenized
loss_mask = [0] + loss_mask
if tokenizer.eos_token_id is not None:
# Force to add eos token at the end of the tokenized sequence
if tokenized[-1] != tokenizer.eos_token_id:
tokenized = tokenized + [tokenizer.eos_token_id]
loss_mask = loss_mask + [1]
else:
loss_mask[-1] = 1
return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
def tokenize_rlhf(
data_point: Dict[str, str],
tokenizer: PreTrainedTokenizer,
conversation_template: Conversation = None,
ignore_index: int = None,
max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
A tokenization function to tokenize an original pretraining data point as following:
{"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
"chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}}
"""
if ignore_index is None:
ignore_index = IGNORE_INDEX
context = data_point["context"]
template = deepcopy(conversation_template)
template.clear()
for mess in context:
from_str = mess["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
# Concate adjacent message from the same role
template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
else:
template.append_message(from_str, mess["content"])
if len(template.messages) % 2 != 1:
warnings.warn(
"Please make sure leading context starts and ends with a line from human\nLeading context: "
+ str(template.messages)
)
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
chosen_label_decode=None,
rejected_input_ids=None,
rejected_loss_mask=None,
rejected_label_decode=None,
)
round_of_context = int((len(template.messages) - 1) / 2)
assert context[-1]["from"].lower() == "human", "The last message in context should be from human."
chosen = deepcopy(template)
rejected = deepcopy(template)
for round in range(len(data_point["chosen"])):
from_str = data_point["chosen"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
chosen.append_message(from_str, data_point["chosen"][round]["content"])
for round in range(len(data_point["rejected"])):
from_str = data_point["rejected"][round]["from"]
if from_str.lower() == "human":
from_str = "user"
elif from_str.lower() == "assistant":
from_str = "assistant"
else:
raise ValueError(f"Unsupported role {from_str.lower()}")
rejected.append_message(from_str, data_point["rejected"][round]["content"])
(
chosen_input_ids,
chosen_loss_mask,
chosen_label_decode,
rejected_input_ids,
rejected_loss_mask,
rejected_label_decode,
) = (None, None, None, None, None, None)
if (
len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
<= max_length - 1
and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
<= max_length - 1
):
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
chosen_data_packed["input_ids"],
chosen_data_packed["loss_mask"],
chosen_data_packed["label_decode"],
)
rejected_data_packed = apply_rlhf_data_format(
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
)
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
rejected_data_packed["input_ids"],
rejected_data_packed["loss_mask"],
rejected_data_packed["label_decode"],
)
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
chosen_label_decode=None,
rejected_input_ids=None,
rejected_loss_mask=None,
rejected_label_decode=None,
)
return {
"chosen_input_ids": chosen_input_ids,
"chosen_loss_mask": chosen_loss_mask,
"chosen_label_decode": chosen_label_decode,
"rejected_input_ids": rejected_input_ids,
"rejected_loss_mask": rejected_loss_mask,
"rejected_label_decode": rejected_label_decode,
}
else:
return dict(
chosen_input_ids=None,
chosen_loss_mask=None,
chosen_label_decode=None,
rejected_input_ids=None,
rejected_loss_mask=None,
rejected_label_decode=None,
)
import io
import json
from typing import Any, Dict, List
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
f.close()
return jdict
def read_string_by_schema(data: Dict[str, Any], schema: str) -> str:
"""
Read a feild of the dataset be schema
Args:
data: Dict[str, Any]
schema: cascaded feild names seperated by '.'. e.g. person.name.first will access data['person']['name']['first']
"""
keys = schema.split(".")
result = data
for key in keys:
result = result.get(key, None)
if result is None:
return ""
assert isinstance(result, str), f"dataset element is not a string: {result}"
return result
def pad_to_max_len(
sequence: List[torch.Tensor], max_length: int, padding_value: int, batch_first: bool = True, padding_side="left"
):
"""
Args:
sequence: a batch of tensor of shape [batch_size, seq_len] if batch_first==True
"""
if padding_side == "left":
reversed_sequence = [seq.flip(dims=(0,)) for seq in sequence]
padded = torch.nn.utils.rnn.pad_sequence(
sequences=reversed_sequence, batch_first=batch_first, padding_value=padding_value
)
to_pad = max_length - padded.size(1)
padded = F.pad(padded, (0, to_pad), value=padding_value)
return torch.flip(padded, dims=(1,))
elif padding_side == "right":
padded = torch.nn.utils.rnn.pad_sequence(
sequences=sequence, batch_first=batch_first, padding_value=padding_value
)
to_pad = max_length - padded.size(1)
return F.pad(padded, (0, to_pad), value=padding_value)
else:
raise RuntimeError(f"`padding_side` can only be `left` or `right`, " f"but now `{padding_side}`")
def chuncate_sequence(sequence: List[torch.Tensor], max_length: int, dtype: Any):
"""
Args:
sequence: a batch of tensor of shape [batch_size, seq_len] if batch_first==True
"""
return [
torch.Tensor(seq[:max_length]).to(dtype) if len(seq) > max_length else torch.Tensor(seq).to(dtype)
for seq in sequence
]
def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, start_index: int = 0) -> int:
if subseq is None:
return 0
for i in range(start_index, len(seq) - len(subseq) + 1):
if torch.all(seq[i : i + len(subseq)] == subseq):
return i
return -1
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
"""
Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs.
Args:
tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
text (List[str]): The list of texts to tokenize.
require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation.
Returns:
Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids,
the start positions of loss spans, and the end positions of loss spans.
"""
input_ids = []
loss_starts = []
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
return input_ids, loss_starts, loss_ends
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str):
# Seperate templated prompt into chunks by human/assistant's lines, prepare data for tokenize_and_concatenate
start_idx = 0
chunks = []
require_loss = []
for line in messages:
first_occur = prompt.find(line["content"], start_idx)
if prompt[first_occur - 1] != " ":
chunks.append(prompt[start_idx:first_occur])
chunks.append(prompt[first_occur : first_occur + len(line["content"])])
else:
chunks.append(prompt[start_idx : first_occur - 1])
chunks.append(prompt[first_occur - 1 : first_occur + len(line["content"])])
start_idx = first_occur + len(line["content"])
if line["role"].lower() == "assistant":
require_loss.append(False)
require_loss.append(True)
else:
require_loss.append(False)
require_loss.append(False)
chunks.append(prompt[start_idx:])
require_loss.append(False)
return chunks, require_loss
import random
import warnings
from typing import List
import torch
from coati.experience_maker.base import Experience
from colossalai.logging import get_dist_logger
from .base import ExperienceBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
logger = get_dist_logger()
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
......@@ -35,7 +38,7 @@ class NaiveExperienceBuffer(ExperienceBuffer):
if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
def clear(self) -> None:
......@@ -43,6 +46,12 @@ class NaiveExperienceBuffer(ExperienceBuffer):
@torch.no_grad()
def sample(self) -> Experience:
"""
Randomly samples experiences from the buffer.
Returns:
A batch of sampled experiences.
"""
items = random.sample(self.items, self.sample_batch_size)
experience = make_experience_batch(items)
if self.cpu_offload:
......
......@@ -26,6 +26,7 @@ class BufferItem:
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
kl: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
......@@ -34,7 +35,7 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask")
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
......@@ -63,7 +64,7 @@ def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> to
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(("action_log_probs", "action_mask"))
keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask")
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
......
......@@ -3,7 +3,8 @@ from dataclasses import dataclass
from typing import Optional
import torch
from coati.models.base import Actor, Critic, RewardModel
from coati.models import Critic, RewardModel
from transformers import PreTrainedModel
@dataclass
......@@ -28,6 +29,7 @@ class Experience:
action_log_probs: torch.Tensor
values: torch.Tensor
reward: torch.Tensor
kl: torch.Tensor
advantages: torch.Tensor
attention_mask: Optional[torch.LongTensor]
action_mask: Optional[torch.BoolTensor]
......@@ -39,6 +41,7 @@ class Experience:
self.values = self.values.to(device)
self.reward = self.reward.to(device)
self.advantages = self.advantages.to(device)
self.kl = self.kl.to(device)
if self.attention_mask is not None:
self.attention_mask = self.attention_mask.to(device)
if self.action_mask is not None:
......@@ -50,6 +53,7 @@ class Experience:
self.values = self.values.pin_memory()
self.reward = self.reward.pin_memory()
self.advantages = self.advantages.pin_memory()
self.kl = self.kl.pin_memory()
if self.attention_mask is not None:
self.attention_mask = self.attention_mask.pin_memory()
if self.action_mask is not None:
......@@ -58,7 +62,13 @@ class Experience:
class ExperienceMaker(ABC):
def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
"""
Base class for experience makers.
"""
def __init__(
self, actor: PreTrainedModel, critic: Critic, reward_model: RewardModel, initial_model: PreTrainedModel
) -> None:
super().__init__()
self.actor = actor
self.critic = critic
......@@ -67,4 +77,14 @@ class ExperienceMaker(ABC):
@abstractmethod
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass
"""
Abstract method to generate an experience.
Args:
input_ids (torch.Tensor): The input tensor.
attention_mask (torch.Tensor): The attention mask tensor.
**generate_kwargs: Additional keyword arguments for generating the experience.
Returns:
Experience: The generated experience.
"""
"""
experience maker.
"""
import torch
import torch.nn.functional as F
from coati.dataset.utils import find_first_occurrence_subsequence
from coati.models import Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedModel, PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .base import Experience, ExperienceMaker
logger = get_dist_logger()
import torch.distributed as dist
def is_rank_0() -> bool:
return not dist.is_initialized() or dist.get_rank() == 0
class NaiveExperienceMaker(ExperienceMaker):
"""
Naive experience maker.
"""
def __init__(
self,
actor: PreTrainedModel,
critic: Critic,
reward_model: RewardModel,
initial_model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
kl_coef: float = 0.01,
gamma: float = 1.0,
lam: float = 0.95,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
self.gamma = gamma
self.lam = lam
@torch.no_grad()
def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
"""
Calculates the advantage values for each action based on the value and reward tensors.
Args:
value (torch.Tensor): Tensor containing the predicted values from critic.
reward (torch.Tensor): reward of the shape [B, len].
num_actions (int): Number of actions.
Returns:
torch.Tensor: Tensor containing the calculated advantages for each action.
"""
lastgaelam = 0
advantages_reversed = []
for t in reversed(range(num_actions)):
nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0
delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
return advantages
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
"""
Generates an experience using the given input_ids and attention_mask.
Args:
input_ids (torch.Tensor): The input tensor containing the tokenized input sequence.
attention_mask (torch.Tensor): The attention mask tensor indicating which tokens to attend to.
**generate_kwargs: Additional keyword arguments for the generation process.
Returns:
Experience: The generated experience object.
"""
self.actor.eval()
self.critic.eval()
self.initial_model.eval()
self.reward_model.eval()
pad_token_id = self.tokenizer.pad_token_id
stop_token_ids = generate_kwargs.get("stop_token_ids", None)
torch.manual_seed(41) # for tp, gurantee the same input for reward model
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
# Pad to max length
sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
sequence_length = sequences.size(1)
# Calculate auxiliary tensors
attention_mask = None
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
if stop_token_ids is None:
# End the sequence with eos token
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# Left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
else:
# stop_token_ids are given, generation ends with stop_token_ids
action_mask = torch.ones_like(sequences, dtype=torch.bool)
for i in range(sequences.size(0)):
stop_index = find_first_occurrence_subsequence(
sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device)
)
if stop_index == -1:
# Sequence does not contain stop_token_ids, this should never happen BTW
logger.warning(
"Generated sequence does not contain stop_token_ids. Please check your chat template config"
)
else:
# Keep stop tokens
stop_index = input_len + stop_index
action_mask[i, stop_index + len(stop_token_ids) :] = False
generation_end_index = (action_mask == True).sum(dim=-1) - 1
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
# Convert to right padding for the reward model and the critic model
input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
for i in range(sequences.size(0)):
sequence = sequences[i]
bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
eos_index = generation_end_index[i]
sequence_to_pad = sequence[bos_index:eos_index]
sequence_padded = F.pad(
sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
)
input_ids_rm[i] = sequence_padded
if sequence_length - sequence_to_pad.size(0) > 0:
attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
else:
attention_mask_rm[i, :] = 1
attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
r = self.reward_model(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
value = self.critic(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
value = value[:, -num_actions:] * action_mask
advantages = self.calculate_advantage(value, reward, num_actions)
advantages = advantages.detach()
value = value.detach()
r = r.detach()
return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask)
from .base import BaseModel
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import convert_to_lora_module
from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout
__all__ = [
"BaseModel",
"Critic",
"RewardModel",
"PolicyLoss",
"ValueLoss",
"LogSigLoss",
"LogExpLoss",
"convert_to_lora_module",
"DpoLoss",
"generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",
"prepare_inputs_fn",
]
"""
Base class for critic and reward model
"""
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoModel, PretrainedConfig
class BaseModel(nn.Module):
"""
Actor model base class.
Args:
pretrained (str): path to pretrained model.
config (PretrainedConfig): PretrainedConfig used to initiate the base model.
**kwargs: all other kwargs as in AutoModel.from_pretrained
"""
def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
super().__init__()
if pretrained is not None:
if config is not None:
# initialize with config and load weights from pretrained
self.model = AutoModel.from_pretrained(pretrained, config=config, **kwargs)
else:
# initialize with pretrained
self.model = AutoModel.from_pretrained(pretrained, **kwargs)
elif config is not None:
# initialize with config
self.model = AutoModel.from_config(config, **kwargs)
else:
raise ValueError("Either pretrained or config must be provided.")
self.config = self.model.config
# create dummy input to get the size of the last hidden state
if "use_flash_attention_2" in kwargs:
self.model = self.model.cuda()
dummy_input = torch.zeros((1, 1), dtype=torch.long).to(self.model.device)
out = self.model(dummy_input)
self.last_hidden_state_size = out.last_hidden_state.shape[-1]
self.model = self.model.cpu()
# print("self.last_hidden_state_size: ",self.last_hidden_state_size)
def resize_token_embeddings(self, *args, **kwargs):
"""
Resize the token embeddings of the model.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
The resized token embeddings.
"""
return self.model.resize_token_embeddings(*args, **kwargs)
"""
Critic model
"""
from typing import Optional
import torch
import torch.nn as nn
from coati.models import BaseModel
from transformers import PretrainedConfig
class Critic(BaseModel):
"""
Critic model class.
Args:
pretrained (str): path to pretrained model.
config (PretrainedConfig): PretrainedConfig used to initiate the base model.
"""
def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
super().__init__(pretrained=pretrained, config=config, **kwargs)
# et last hidden state size with dummy input
self.value_head = nn.Linear(self.last_hidden_state_size, 1)
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
outputs = self.model(input_ids, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), :].type(
self.value_head.weight.dtype
)
values = self.value_head(sequence_hidden_states).squeeze(-1) # ensure shape is (B, sequence length)
return values
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer
from .base import Actor
try:
from transformers.generation_logits_process import (
LogitsProcessorList,
......@@ -20,6 +18,18 @@ except ImportError:
def _prepare_logits_processor(
top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
"""
Prepare the logits processor list based on the given parameters.
Args:
top_k (Optional[int]): The number of highest probability logits to keep for each token.
top_p (Optional[float]): The cumulative probability threshold for selecting tokens.
temperature (Optional[float]): The temperature value to apply to the logits.
Returns:
LogitsProcessorList: The list of logits processors.
"""
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
......@@ -31,6 +41,15 @@ def _prepare_logits_processor(
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
"""
Check if the sequence generation is finished.
Args:
unfinished_sequences (torch.Tensor): Tensor indicating the unfinished sequences.
Returns:
bool: True if all sequences are finished, False otherwise.
"""
if dist.is_initialized() and dist.get_world_size() > 1:
# consider DP
unfinished_sequences = unfinished_sequences.clone()
......@@ -38,31 +57,275 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
"""
Update the model keyword arguments based on the outputs and new mask.
Args:
outputs (dict): The outputs from the model.
new_mask: The new attention mask.
**model_kwargs: Additional model keyword arguments.
Returns:
dict: The updated model keyword arguments.
"""
if "past_key_values" in outputs:
model_kwargs["past_key_values"] = outputs["past_key_values"]
else:
model_kwargs["past_key_values"] = None
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat([attention_mask, new_mask], dim=-1)
return model_kwargs
def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
model_kwargs["input_ids"] = input_ids
return model_kwargs
def _sample(
model: Actor,
model: Any,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = True,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
stop_token_ids: Optional[List[int]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
max_new_tokens: int = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
stream_interval: int = 2,
**model_kwargs,
) -> torch.Tensor:
"""
Generates new tokens using the given model and input_ids.
Args:
model (Any): The model used for token generation.
input_ids (torch.Tensor): The input tensor containing the initial tokens.
max_length (int): The maximum length of the generated tokens.
early_stopping (bool, optional): Whether to stop generating tokens early if all sequences are finished. Defaults to True.
eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None.
pad_token_id (int, optional): The ID of the padding token. Defaults to None.
stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will stop the generation process. Defaults to None.
top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None.
top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None.
temperature (float, optional): The temperature value for token sampling. Defaults to None.
max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None.
prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None.
update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model kwargs. Defaults to None.
stream_interval (int, optional): The interval for streaming generation. Defaults to 2.
**model_kwargs: Additional keyword arguments for the model.
Returns:
torch.Tensor: The tensor containing the generated tokens.
"""
context_length = input_ids.size(1)
if max_new_tokens is None:
max_new_tokens = max_length - context_length
if context_length + max_new_tokens > max_length or max_new_tokens == 0:
return input_ids
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
past = None
for i in range(context_length, context_length + max_new_tokens):
# Calculate attention mask
if "attention_mask" not in model_kwargs:
model_kwargs["attention_mask"] = input_ids.ne(pad_token_id)
model_inputs = (
prepare_inputs_fn(input_ids, past=past, **model_kwargs)
if prepare_inputs_fn is not None
else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)}
)
outputs = model(**model_inputs)
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
# NOTE: this is correct only in left padding mode
next_token_logits = outputs["logits"][:, -1, :]
next_token_logits = logits_processor(input_ids, next_token_logits)
# Sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# Finished sentences should have their next token be a padding token
if eos_token_id is not None:
assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# Update generated ids, model inputs for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if update_model_kwargs_fn is not None:
model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
# If eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
if stop_token_ids is not None:
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
tokens_to_check = input_ids[:, -len(stop_token_ids) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
# Stop when each sentence is finished if early_stopping=True
if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
if i == context_length + max_new_tokens - 1:
# Force to end with stop token ids
input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
torch.LongTensor(stop_token_ids).to(input_ids.device).long()
)
return input_ids
@torch.inference_mode()
def generate(
model: Any,
input_ids: torch.Tensor,
tokenizer: PreTrainedTokenizer,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = True,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
model (nn.Module): model
input_ids (torch.Tensor): input sequence
max_length (int): max length of the returned sequence
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
raise NotImplementedError
elif is_sample_gen_mode:
# Run sample
res = _sample(
model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
)
return res
elif is_beam_gen_mode:
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")
def _sample_streaming(
model: Any,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
stop_token_ids: Optional[List[int]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
max_new_tokens: int = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
stream_interval: int = 2,
**model_kwargs,
) -> torch.Tensor:
if input_ids.size(1) >= max_length:
"""
Generates new tokens using a streaming approach.
Args:
model (Any): The model used for token generation.
input_ids (torch.Tensor): The input tensor containing the initial tokens.
max_length (int): The maximum length of the generated sequence.
early_stopping (bool, optional): Whether to stop generating tokens for a sequence if it is finished. Defaults to False.
eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None.
pad_token_id (int, optional): The ID of the padding token. Defaults to None.
stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will mark the sequence as finished. Defaults to None.
top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None.
top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None.
temperature (float, optional): The temperature value for sampling. Defaults to None.
max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None.
prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None.
update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model keyword arguments. Defaults to None.
stream_interval (int, optional): The interval at which to yield the generated tokens. Defaults to 2.
**model_kwargs: Additional keyword arguments to be passed to the model.
Yields:
torch.Tensor: The generated tokens at each step.
Returns:
torch.Tensor: The final generated tokens.
"""
context_length = input_ids.size(1)
if max_new_tokens is None:
max_new_tokens = max_length - context_length
if context_length + max_new_tokens > max_length or max_new_tokens == 0:
return input_ids
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
past = None
for i in range(context_length, context_length + max_new_tokens):
# calculate attention mask
if "attention_mask" not in model_kwargs:
model_kwargs["attention_mask"] = input_ids.ne(pad_token_id)
model_inputs = (
prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
prepare_inputs_fn(input_ids, past=past, **model_kwargs)
if prepare_inputs_fn is not None
else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)}
)
outputs = model(**model_inputs)
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
# NOTE: this is correct only in left padding mode
next_token_logits = outputs["logits"][:, -1, :]
......@@ -78,6 +341,7 @@ def _sample(
# update generated ids, model inputs for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if update_model_kwargs_fn is not None:
model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
......@@ -85,16 +349,27 @@ def _sample(
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished if early_stopping=True
if early_stopping and _is_sequence_finished(unfinished_sequences):
break
if stop_token_ids is not None:
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
tokens_to_check = input_ids[:, -len(stop_token_ids) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
return input_ids
# Stop when each sentence is finished if early_stopping=True
if (
(early_stopping and _is_sequence_finished(unfinished_sequences))
or (i - context_length) % stream_interval == 0
or i == context_length + max_new_tokens - 1
):
yield input_ids
if early_stopping and _is_sequence_finished(unfinished_sequences):
break
@torch.no_grad()
def generate(
model: Actor,
@torch.inference_mode()
def generate_streaming(
model: Any,
input_ids: torch.Tensor,
tokenizer: PreTrainedTokenizer,
max_length: int,
......@@ -107,7 +382,7 @@ def generate(
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
):
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
......@@ -132,7 +407,7 @@ def generate(
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
return _sample(
for res in _sample_streaming(
model,
input_ids,
max_length,
......@@ -145,7 +420,8 @@ def generate(
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
)
):
yield res
elif is_beam_gen_mode:
raise NotImplementedError
else:
......
"""
LORA utils
"""
import dataclasses
import math
import warnings
......@@ -8,6 +12,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
@dataclasses.dataclass
class LoRAManager:
......@@ -58,6 +66,10 @@ class LoraLinear(lora.LoRALayer, nn.Module):
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
"""
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
"""
def T(w):
return w.T if self.fan_in_fan_out else w
......@@ -101,6 +113,16 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
"""
Wraps a linear layer with LoRA functionality.
Args:
linear (nn.Linear): The linear layer to be wrapped.
lora_rank (int): The rank of the LoRA decomposition.
Returns:
LoraLinear: The wrapped linear layer with LoRA functionality.
"""
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
......@@ -109,6 +131,16 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
"""
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
Args:
module (nn.Module): The module to convert to LoRA form.
lora_rank (int): The rank of the LoRA approximation.
Returns:
None
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
......@@ -131,23 +163,3 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
_convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
This class will convert all torch.nn.Linear layer to LoraLinear layer.
Args:
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
lora_train_bias (str, optional): Whether LoRA train biases.
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
def convert_to_lora(self) -> None:
convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment