Unverified commit da4f7b85 authored by Wenhao Chen, committed by GitHub

[chat] fix bugs and add unit tests (#4213)

* style: rename replay buffer

Experience replay is typically used for off-policy algorithms.
Using this name in PPO may be misleading.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
parent 16bf4c02
@@ -43,7 +43,9 @@ jobs:
       run: |
         cd applications/Chat
         rm -rf ~/.cache/colossalai
-        ./examples/test_ci.sh
+        ./tests/test_inference.sh
+        ./tests/test_benchmarks.sh
+        ./tests/test_train.sh
       env:
         NCCL_SHM_DISABLE: 1
         MAX_JOBS: 8
...
 from .prompt_dataset import PromptDataset
 from .reward_dataset import HhRlhfDataset, RmStaticDataset
-from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
+from .sft_dataset import SFTDataset, SupervisedDataset
 from .utils import is_rank_0

 __all__ = [
-    'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
-    'DataCollatorForSupervisedDataset', 'PromptDataset'
+    'RmStaticDataset', 'HhRlhfDataset',
+    'SFTDataset', 'SupervisedDataset',
+    'PromptDataset', 'is_rank_0',
 ]
-import copy
-import random
 from collections import defaultdict
-from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Dict

 import torch
-import torch.distributed as dist
 import transformers
 from torch.utils.data import Dataset
-from tqdm import tqdm

 from colossalai.logging import get_dist_logger

-from .utils import is_rank_0, jload
+from .utils import jload

-logger = get_dist_logger()

 class PromptDataset(Dataset):
@@ -27,12 +20,13 @@ class PromptDataset(Dataset):
                  max_length: int = 96):
         super(PromptDataset, self).__init__()
         self.keyed_prompt = defaultdict(list)
-        logger.info("Loading data...")
+        self.logger = get_dist_logger()
+        self.logger.info("Loading data...")
         list_data_dict = jload(data_path)
-        logger.info(f"Loaded {len(list_data_dict)} examples.")
+        self.logger.info(f"Loaded {len(list_data_dict)} examples.")

         if max_datasets_size is not None:
-            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
+            self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
             list_data_dict = list_data_dict[:max_datasets_size]

         instructions = [data_dict["instruction"] for data_dict in list_data_dict]
...
@@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):
     def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
         super().__init__()
-        self.chosen = []
-        self.reject = []
-        if special_token is None:
-            self.end_token = tokenizer.eos_token
-        else:
-            self.end_token = special_token
-        for data in tqdm(dataset, disable=not is_rank_0()):
-            prompt = data['prompt']
-
-            chosen = prompt + data['chosen'] + self.end_token
-            chosen_token = tokenizer(chosen,
-                                     max_length=max_length,
-                                     padding="max_length",
-                                     truncation=True,
-                                     return_tensors="pt")
-            self.chosen.append({
-                "input_ids": chosen_token['input_ids'],
-                "attention_mask": chosen_token['attention_mask']
-            })
-
-            reject = prompt + data['rejected'] + self.end_token
-            reject_token = tokenizer(reject,
-                                     max_length=max_length,
-                                     padding="max_length",
-                                     truncation=True,
-                                     return_tensors="pt")
-            self.reject.append({
-                "input_ids": reject_token['input_ids'],
-                "attention_mask": reject_token['attention_mask']
-            })
+        self.end_token = tokenizer.eos_token \
+            if special_token is None else special_token
+
+        chosen = [
+            data["prompt"] + data["chosen"] + self.end_token
+            for data in tqdm(dataset, disable=not is_rank_0())
+        ]
+        chosen_token = tokenizer(chosen,
+                                 max_length=max_length,
+                                 padding="max_length",
+                                 truncation=True,
+                                 return_tensors="pt")
+        self.chosen = {
+            "input_ids": chosen_token["input_ids"],
+            "attention_mask": chosen_token["attention_mask"]
+        }
+
+        reject = [
+            data["prompt"] + data["rejected"] + self.end_token
+            for data in tqdm(dataset, disable=not is_rank_0())
+        ]
+        reject_token = tokenizer(reject,
+                                 max_length=max_length,
+                                 padding="max_length",
+                                 truncation=True,
+                                 return_tensors="pt")
+        self.reject = {
+            "input_ids": reject_token["input_ids"],
+            "attention_mask": reject_token["attention_mask"]
+        }

     def __len__(self):
-        length = len(self.chosen)
+        length = self.chosen["input_ids"].shape[0]
         return length

     def __getitem__(self, idx):
-        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
-            "input_ids"], self.reject[idx]["attention_mask"]
+        return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
+            self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
 # Anthropic/hh-rlhf
@@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):
     def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
         super().__init__()
-        self.chosen = []
-        self.reject = []
-        if special_token is None:
-            self.end_token = tokenizer.eos_token
-        else:
-            self.end_token = special_token
-        for data in tqdm(dataset, disable=not is_rank_0()):
-            chosen = data['chosen'] + self.end_token
-            chosen_token = tokenizer(chosen,
-                                     max_length=max_length,
-                                     padding="max_length",
-                                     truncation=True,
-                                     return_tensors="pt")
-            self.chosen.append({
-                "input_ids": chosen_token['input_ids'],
-                "attention_mask": chosen_token['attention_mask']
-            })
-
-            reject = data['rejected'] + self.end_token
-            reject_token = tokenizer(reject,
-                                     max_length=max_length,
-                                     padding="max_length",
-                                     truncation=True,
-                                     return_tensors="pt")
-            self.reject.append({
-                "input_ids": reject_token['input_ids'],
-                "attention_mask": reject_token['attention_mask']
-            })
+        self.end_token = tokenizer.eos_token \
+            if special_token is None else special_token
+
+        chosen = [
+            data["chosen"] + self.end_token
+            for data in tqdm(dataset, disable=not is_rank_0())
+        ]
+        chosen_token = tokenizer(chosen,
+                                 max_length=max_length,
+                                 padding="max_length",
+                                 truncation=True,
+                                 return_tensors="pt")
+        self.chosen = {
+            "input_ids": chosen_token["input_ids"],
+            "attention_mask": chosen_token["attention_mask"]
+        }
+
+        reject = [
+            data["rejected"] + self.end_token
+            for data in tqdm(dataset, disable=not is_rank_0())
+        ]
+        reject_token = tokenizer(reject,
+                                 max_length=max_length,
+                                 padding="max_length",
+                                 truncation=True,
+                                 return_tensors="pt")
+        self.reject = {
+            "input_ids": reject_token["input_ids"],
+            "attention_mask": reject_token["attention_mask"]
+        }

     def __len__(self):
-        length = len(self.chosen)
+        length = self.chosen["input_ids"].shape[0]
         return length

     def __getitem__(self, idx):
-        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
-            "input_ids"], self.reject[idx]["attention_mask"]
+        return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
+            self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
@@ -13,44 +13,64 @@
 # limitations under the License.

 import copy
-import random
-from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Sequence, Tuple
+from typing import Dict, Sequence, Tuple

 import torch
-import torch.distributed as dist
-import transformers
 from torch.utils.data import Dataset
 from tqdm import tqdm
+from transformers import PreTrainedTokenizer

 from colossalai.logging import get_dist_logger

-from .conversation import default_conversation
 from .utils import is_rank_0, jload

+# The following is a template prompt for a 4-round conversation.
+"""
+A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
+"""
+# Please note that we only calculate loss on assistant's answer tokens.
+
 logger = get_dist_logger()

 IGNORE_INDEX = -100
-DEFAULT_EOS_TOKEN = "</s>"
 PROMPT_DICT = {
-    "prompt_input":
-        ("Below is an instruction that describes a task, paired with an input that provides further context. "
-         "Write a response that appropriately completes the request.\n\n"
-         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
+    "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
+                     "Write a response that appropriately completes the request.\n\n"
+                     "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
     "prompt_no_input": ("Below is an instruction that describes a task. "
                         "Write a response that appropriately completes the request.\n\n"
                         "### Instruction:\n{instruction}\n\n### Response:"),
 }

+
+def _preprocess(sources: Sequence[str],
+                targets: Sequence[str],
+                tokenizer: PreTrainedTokenizer,
+                max_length: int,
+                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Preprocess the data by tokenizing."""
+    sequences = [s + t for s, t in zip(sources, targets)]
+    sequences_token = tokenizer(sequences,
+                                max_length=max_length,
+                                padding="max_length",
+                                truncation=True,
+                                return_tensors="pt")
+    sources_token = tokenizer(sources,
+                              max_length=max_length,
+                              padding="max_length",
+                              truncation=True,
+                              return_tensors="pt")
+
+    labels = copy.deepcopy(sequences_token["input_ids"])
+    for i in range(labels.shape[0]):
+        source_len = sources_token["attention_mask"][i].sum().item()
+        pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
+        if tokenizer.padding_side == "right":
+            # |prompt|completion|eos|pad|
+            labels[i][:source_len] = IGNORE_INDEX
+        elif tokenizer.padding_side == "left":
+            # |pad|prompt|completion|eos|
+            labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
+        else:
+            raise RuntimeError()
+
+    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
 class SFTDataset(Dataset):
     """
     Dataset for sft model
@@ -61,115 +81,31 @@ class SFTDataset(Dataset):
         max_length: max length of input
     """

-    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
+    def __init__(self,
+                 dataset: Dict,
+                 tokenizer: PreTrainedTokenizer,
+                 max_length: int = 512
+                 ) -> None:
         super().__init__()
-        self.input_ids = []
-        for data in tqdm(dataset, disable=not is_rank_0()):
-            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
-            prompt_token = tokenizer(prompt,
-                                     max_length=max_length,
-                                     padding="max_length",
-                                     truncation=True,
-                                     return_tensors="pt")
-            self.input_ids.append(prompt_token['input_ids'][0])
-        self.labels = copy.deepcopy(self.input_ids)
+        sources = [data["prompt"] for data in dataset]
+        targets = [
+            data["completion"] + tokenizer.eos_token
+            for data in tqdm(dataset, disable=not is_rank_0())
+        ]
+
+        self.input_ids, self.labels, self.attention_mask = \
+            _preprocess(sources, targets, tokenizer, max_length)

     def __len__(self):
-        length = len(self.input_ids)
+        length = self.input_ids.shape[0]
         return length

     def __getitem__(self, idx):
-        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
+        return dict(input_ids=self.input_ids[idx],
+                    labels=self.labels[idx],
+                    attention_mask=self.attention_mask[idx])
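A worked toy example of the label masking that `_preprocess` performs for SFTDataset: prompt positions are set to IGNORE_INDEX (-100), which torch.nn.CrossEntropyLoss skips by default, so loss is computed only on completion tokens. Toy ids, not real tokenizer output:

import torch

IGNORE_INDEX = -100    # default ignore_index of torch.nn.CrossEntropyLoss

# Right padding: |prompt(3)|completion+eos(3)|pad(2)|
input_ids = torch.tensor([[11, 12, 13, 21, 22, 2, 0, 0]])
source_len = 3    # prompt length, from sources_token["attention_mask"].sum()
labels = input_ids.clone()
labels[0, :source_len] = IGNORE_INDEX    # no loss on prompt tokens
print(labels)    # tensor([[-100, -100, -100,   21,   22,    2,    0,    0]])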
-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
-                 max_length: int) -> Dict[str, torch.Tensor]:
-    """Tokenize a list of strings."""
-    tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
-    input_ids = labels = tokenized_list["input_ids"]
-    input_ids_lens = labels_lens = \
-        tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
-    return dict(
-        input_ids=input_ids,
-        labels=labels,
-        input_ids_lens=input_ids_lens,
-        labels_lens=labels_lens,
-    )
-
-
-def preprocess(
-    sources: Sequence[str],
-    targets: Sequence[str],
-    tokenizer: transformers.PreTrainedTokenizer,
-    max_length: int,
-) -> Dict:
-    """Preprocess the data by tokenizing."""
-    examples = [s + t for s, t in zip(sources, targets)]
-    examples_tokenized, sources_tokenized = [
-        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
-    ]
-    input_ids = examples_tokenized["input_ids"]
-    labels = copy.deepcopy(input_ids)
-    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
-        label[:source_len] = IGNORE_INDEX
-    return dict(input_ids=input_ids, labels=labels)
-
-
-def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
-                            max_length: int) -> Dict:
-    """Preprocess the conversation data by tokenizing."""
-    conversations = []
-    intermediates = []
-    for source in sources:
-        header = f"{default_conversation.system}"
-        conversation, intermediate = _add_speaker_and_signal(header, source)
-        conversations.append(conversation)
-        intermediates.append(intermediate)
-
-    conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
-    input_ids = conversations_tokenized["input_ids"]
-    targets = copy.deepcopy(input_ids)
-
-    assert len(targets) == len(intermediates)
-    for target, inters in zip(targets, intermediates):
-        mask = torch.zeros_like(target, dtype=torch.bool)
-        for inter in inters:
-            tokenized = _tokenize_fn(inter, tokenizer, max_length)
-
-            start_idx = tokenized["input_ids"][0].size(0) - 1
-            end_idx = tokenized["input_ids"][1].size(0)
-
-            mask[start_idx:end_idx] = True
-        target[~mask] = IGNORE_INDEX
-
-    return dict(input_ids=input_ids, labels=targets)
-
-
-def _add_speaker_and_signal(header: str,
-                            source: List[Dict],
-                            get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
-    END_SIGNAL = DEFAULT_EOS_TOKEN
-    conversation = header
-    intermediate = []
-    for sentence in source:
-        from_str = sentence["from"]
-        if from_str.lower() == "human":
-            from_str = default_conversation.roles[0]
-        elif from_str.lower() == "gpt":
-            from_str = default_conversation.roles[1]
-        else:
-            from_str = 'unknown'
-
-        value = from_str + ": " + sentence["value"] + END_SIGNAL
-        if sentence["from"].lower() == "gpt":
-            start = conversation + from_str + ": "
-            end = conversation + value
-            intermediate.append([start, end])
-        if get_conversation:
-            conversation += value
-
-    return conversation, intermediate
 class SupervisedDataset(Dataset):
@@ -177,10 +113,10 @@ class SupervisedDataset(Dataset):
     def __init__(self,
                  data_path: str,
-                 tokenizer: transformers.PreTrainedTokenizer,
+                 tokenizer: PreTrainedTokenizer,
                  max_datasets_size: int = None,
                  max_length: int = 512):
-        super(SupervisedDataset, self).__init__()
+        super().__init__()
         logger.info("Loading data...")
         list_data_dict = jload(data_path)
         logger.info(f"Loaded {len(list_data_dict)} examples.")
@@ -190,52 +126,25 @@ class SupervisedDataset(Dataset):
             list_data_dict = list_data_dict[:max_datasets_size]

         logger.info("Formatting inputs...")
-        if "conversations" not in list_data_dict[0]:
-            prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
-            sources = [
-                prompt_input.format_map(example)
-                if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
-            ]
-            targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
-
-            if is_rank_0():
-                logger.info("Tokenizing inputs... This may take some time...")
-
-            data_dict = preprocess(sources, targets, tokenizer, max_length)
-        else:
-            if is_rank_0():
-                logger.info("Tokenizing inputs... This may take some time...")
-            sources = [conv["conversations"] for conv in list_data_dict]
-            data_dict = preprocess_conversation(sources, tokenizer, max_length)
-
-        if is_rank_0():
-            logger.info("Tokenizing finish.")
-
-        self.input_ids = data_dict["input_ids"]
-        self.labels = data_dict["labels"]
+        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+        sources = [
+            prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
+            for example in list_data_dict
+        ]
+        targets = [
+            example['output'] + tokenizer.eos_token
+            for example in list_data_dict
+        ]
+
+        logger.info("Tokenizing inputs... This may take some time...")
+        self.input_ids, self.labels, self.attention_mask = \
+            _preprocess(sources, targets, tokenizer, max_length)

     def __len__(self):
-        return len(self.input_ids)
+        length = self.input_ids.shape[0]
+        return length

-    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
-
-
-@dataclass
-class DataCollatorForSupervisedDataset(object):
-    """Collate examples for supervised fine-tuning."""
-
-    tokenizer: transformers.PreTrainedTokenizer
-
-    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
-        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
-                                                    batch_first=True,
-                                                    padding_value=self.tokenizer.pad_token_id)
-        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
-        return dict(
-            input_ids=input_ids,
-            labels=labels,
-            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
-        )
+    def __getitem__(self, idx):
+        return dict(input_ids=self.input_ids[idx],
+                    labels=self.labels[idx],
+                    attention_mask=self.attention_mask[idx])
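Note: DataCollatorForSupervisedDataset could be deleted because `_preprocess` now pads every sample to a fixed max_length and returns the attention mask itself, so PyTorch's default collate simply stacks the per-sample tensors; dynamic padding with pad_sequence is no longer needed. A usage sketch (model name and data are placeholders):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from coati.dataset import SFTDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")    # placeholder
data = [{"prompt": "Q: 1+1=", "completion": " 2"} for _ in range(4)]

dataset = SFTDataset(data, tokenizer, max_length=64)
loader = DataLoader(dataset, batch_size=2, shuffle=True)    # no custom collate_fn needed
batch = next(iter(loader))
print(batch["input_ids"].shape)    # torch.Size([2, 64])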
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
@@ -4,8 +4,8 @@ from typing import Any
 from coati.experience_maker.base import Experience

-class ReplayBuffer(ABC):
-    """Replay buffer base class. It stores experience.
+class ExperienceBuffer(ABC):
+    """Experience buffer base class. It stores experience.

     Args:
         sample_batch_size (int): Batch size when sampling.
...
@@ -4,12 +4,12 @@ from typing import List
 import torch
 from coati.experience_maker.base import Experience

-from .base import ReplayBuffer
+from .base import ExperienceBuffer
 from .utils import BufferItem, make_experience_batch, split_experience_batch

-class NaiveReplayBuffer(ReplayBuffer):
-    """Naive replay buffer class. It stores experience.
+class NaiveExperienceBuffer(ExperienceBuffer):
+    """Naive experience buffer class. It stores experience.

     Args:
         sample_batch_size (int): Batch size when sampling.
...
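Note: downstream imports must follow the rename, since the old NaiveReplayBuffer paths break. A minimal sketch (the package path and the constructor defaults beyond sample_batch_size are assumptions based on the old class):

# assumed import path after the directory rename
from coati.experience_buffer import NaiveExperienceBuffer

buffer = NaiveExperienceBuffer(sample_batch_size=8)
# typical flow: buffer.append(experience) per rollout, then batch = buffer.sample()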
@@ -33,7 +33,8 @@ class BufferItem:
 def split_experience_batch(experience: Experience) -> List[BufferItem]:
     batch_size = experience.sequences.size(0)
     batch_kwargs = [{} for _ in range(batch_size)]
-    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
+    keys = ('sequences', 'action_log_probs', 'values',
+            'reward', 'advantages', 'attention_mask', 'action_mask')
     for key in keys:
         value = getattr(experience, key)
         if isinstance(value, torch.Tensor):
@@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
     return items

-def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
+def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
     assert side in ('left', 'right')
     max_len = max(seq.size(0) for seq in sequences)
     padded_sequences = []
@@ -62,11 +63,12 @@ def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
 def make_experience_batch(items: List[BufferItem]) -> Experience:
     kwargs = {}
     to_pad_keys = set(('action_log_probs', 'action_mask'))
-    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
+    keys = ('sequences', 'action_log_probs', 'values',
+            'reward', 'advantages', 'attention_mask', 'action_mask')
     for key in keys:
         vals = [getattr(item, key) for item in items]
         if key in to_pad_keys:
-            batch_data = zero_pad_sequences(vals)
+            batch_data = _zero_pad_sequences(vals)
         else:
            batch_data = torch.stack(vals, dim=0)
         kwargs[key] = batch_data
...
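A standalone toy illustration of what `_zero_pad_sequences` does for the two per-item fields whose lengths can differ (action_log_probs, action_mask); this mirrors the function above rather than importing the repository code:

import torch
import torch.nn.functional as F

def zero_pad(sequences, side='left'):
    # pad each 1-D tensor with zeros up to the longest one, then stack
    max_len = max(seq.size(0) for seq in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == 'left' else (0, pad_len)
        padded.append(F.pad(seq, padding))
    return torch.stack(padded, dim=0)

print(zero_pad([torch.tensor([1, 2, 3]), torch.tensor([4])]))
# tensor([[1, 2, 3],
#         [0, 0, 4]])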
 import torch
-from coati.models.generation import generate_with_actor
-from coati.models.utils import calc_action_log_probs, compute_reward, normalize
+import torch.nn.functional as F
+from coati.models.generation import generate
+from coati.models.utils import calc_action_log_probs, compute_reward

 from .base import Experience, ExperienceMaker

@@ -17,10 +18,27 @@ class NaiveExperienceMaker(ExperienceMaker):
         self.initial_model.eval()
         self.reward_model.eval()

-        sequences, attention_mask, action_mask = generate_with_actor(self.actor,
-                                                                     input_ids,
-                                                                     return_action_mask=True,
-                                                                     **generate_kwargs)
+        # generate sequences
+        sequences = generate(self.actor, input_ids, **generate_kwargs)
+
+        # calculate auxiliary tensors
+        attention_mask = None
+        pad_token_id = generate_kwargs.get('pad_token_id', None)
+        if pad_token_id is not None:
+            attention_mask = sequences.not_equal(pad_token_id)\
+                .to(dtype=torch.long, device=sequences.device)
+
+        input_len = input_ids.size(1)
+        eos_token_id = generate_kwargs.get('eos_token_id', None)
+        if eos_token_id is None:
+            action_mask = torch.ones_like(sequences, dtype=torch.bool)
+        else:
+            # left padding may be applied, only mask action
+            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
+        action_mask[:, :input_len] = False
+        action_mask = action_mask[:, 1:]
+        action_mask = action_mask[:, -(sequences.size(1) - input_len):]
         num_actions = action_mask.size(1)

         actor_output = self.actor(sequences, attention_mask)
...
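A worked toy example of the inlined action-mask logic above: everything after the first eos in the generated part is masked, the eos itself and the actions before it are kept, and prompt positions are excluded. Toy ids with eos_token_id = 2:

import torch
import torch.nn.functional as F

sequences = torch.tensor([[5, 6, 7, 8, 2, 9]])    # prompt [5, 6], actions [7, 8, 2, 9]
input_len, eos_token_id = 2, 2

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
# action part [7, 8, 2, 9] -> cumsum of (==eos) is [0, 0, 1, 1] -> mask [T, T, F, F]
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
print(action_mask)    # tensor([[True, True, True, False]]): 7, 8 and eos count, the trailing 9 does not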
 from .base import Actor, Critic, RewardModel
 from .lora import LoRAModule, convert_to_lora_module
-from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss

 __all__ = [
-    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
+    'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
     'LoRAModule', 'convert_to_lora_module'
 ]
@@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (BloomConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
     def __init__(self,
                  pretrained: str = None,
                  config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -32,7 +30,6 @@ class BLOOMCritic(Critic):
             model = BloomModel(config)
         else:
             model = BloomModel(BloomConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
@@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (BloomConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
     def __init__(self,
                  pretrained: str = None,
                  config: Optional[BloomConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
@@ -30,8 +28,7 @@ class BLOOMRM(RewardModel):
             model = BloomModel(config)
         else:
             model = BloomModel(BloomConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
         super().__init__(model, value_head, lora_rank, lora_train_bias)
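Note: per the commit log, the checkpoint argument was dropped because gradient checkpointing in the RM/critic caused errors (and it is disabled together with LoRA anyway). If needed, it can still be switched on from outside; a sketch, assuming the wrapper keeps the backbone in a .model attribute:

rm = BLOOMRM(pretrained="bigscience/bloom-560m")    # model name is illustrative
rm.model.gradient_checkpointing_enable()    # standard HF API on the underlying BloomModel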
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional

 import torch
 import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
+
+from .base import Actor

 try:
     from transformers.generation_logits_process import (
@@ -16,9 +16,9 @@ except ImportError:
     from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

-def prepare_logits_processor(top_k: Optional[int] = None,
-                             top_p: Optional[float] = None,
-                             temperature: Optional[float] = None) -> LogitsProcessorList:
+def _prepare_logits_processor(top_k: Optional[int] = None,
+                              top_p: Optional[float] = None,
+                              temperature: Optional[float] = None) -> LogitsProcessorList:
     processor_list = LogitsProcessorList()
     if temperature is not None and temperature != 1.0:
         processor_list.append(TemperatureLogitsWarper(temperature))
@@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
     return unfinished_sequences.max() == 0

-def sample(model: nn.Module,
-           input_ids: torch.Tensor,
-           max_length: int,
-           early_stopping: bool = False,
-           eos_token_id: Optional[int] = None,
-           pad_token_id: Optional[int] = None,
-           top_k: Optional[int] = None,
-           top_p: Optional[float] = None,
-           temperature: Optional[float] = None,
-           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
-           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
-           **model_kwargs) -> torch.Tensor:
+def _sample(model: Actor,
+            input_ids: torch.Tensor,
+            max_length: int,
+            early_stopping: bool = False,
+            eos_token_id: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            top_k: Optional[int] = None,
+            top_p: Optional[float] = None,
+            temperature: Optional[float] = None,
+            prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+            update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+            **model_kwargs) -> torch.Tensor:
     if input_ids.size(1) >= max_length:
         return input_ids

-    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
+    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

     for _ in range(input_ids.size(1), max_length):
@@ -89,7 +89,8 @@ def _sample(model: Actor,
     return input_ids

-def generate(model: nn.Module,
+@torch.no_grad()
+def generate(model: Actor,
              input_ids: torch.Tensor,
              max_length: int,
              num_beams: int = 1,
@@ -128,51 +129,19 @@ def generate(model: Actor,
         raise NotImplementedError
     elif is_sample_gen_mode:
         # run sample
-        return sample(model,
-                      input_ids,
-                      max_length,
-                      early_stopping=early_stopping,
-                      eos_token_id=eos_token_id,
-                      pad_token_id=pad_token_id,
-                      top_k=top_k,
-                      top_p=top_p,
-                      temperature=temperature,
-                      prepare_inputs_fn=prepare_inputs_fn,
-                      update_model_kwargs_fn=update_model_kwargs_fn,
-                      **model_kwargs)
+        return _sample(model,
+                       input_ids,
+                       max_length,
+                       early_stopping=early_stopping,
+                       eos_token_id=eos_token_id,
+                       pad_token_id=pad_token_id,
+                       top_k=top_k,
+                       top_p=top_p,
+                       temperature=temperature,
+                       prepare_inputs_fn=prepare_inputs_fn,
+                       update_model_kwargs_fn=update_model_kwargs_fn,
+                       **model_kwargs)
     elif is_beam_gen_mode:
         raise NotImplementedError
     else:
         raise ValueError("Unsupported generation mode")
-@torch.no_grad()
-def generate_with_actor(
-        actor_model: nn.Module,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
-) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
-    """Generate token sequence with actor model. Refer to `generate` for more details.
-    """
-    # generate sequences
-    sequences = generate(actor_model, input_ids, **kwargs)
-    # calculate auxiliary tensors
-    attention_mask = None
-    pad_token_id = kwargs.get('pad_token_id', None)
-    if pad_token_id is not None:
-        attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
-    if not return_action_mask:
-        return sequences, attention_mask, None
-    input_len = input_ids.size(1)
-    eos_token_id = kwargs.get('eos_token_id', None)
-    if eos_token_id is None:
-        action_mask = torch.ones_like(sequences, dtype=torch.bool)
-    else:
-        # left padding may be applied, only mask action
-        action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-        action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
-        action_mask[:, :input_len] = False
-        action_mask = action_mask[:, 1:]
-    return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
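With generate_with_actor removed, callers use `generate` directly (now under @torch.no_grad()) and derive attention/action masks themselves, exactly as NaiveExperienceMaker does above. A minimal usage sketch (actor construction and token ids are illustrative; depending on the model, prepare_inputs_fn/update_model_kwargs_fn may also need to be passed through the kwargs):

import torch
from coati.models.generation import generate
from coati.models.gpt import GPTActor

actor = GPTActor()    # randomly initialized GPT-2 actor, for illustration only
input_ids = torch.randint(0, 50257, (1, 8))
sequences = generate(actor, input_ids, max_length=16,
                     do_sample=True, temperature=1.0,
                     pad_token_id=50256, eos_token_id=50256)
print(sequences.shape)    # up to (1, 16)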
@@ -14,7 +14,6 @@ class GPTCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the LO-RA decomposition.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class GPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -32,7 +30,6 @@ class GPTCritic(Critic):
             model = GPT2Model(config)
         else:
             model = GPT2Model(GPT2Config())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
@@ -14,7 +14,6 @@ class GPTRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the low-rank approximation.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -22,7 +21,6 @@ class GPTRM(RewardModel):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
@@ -31,8 +29,6 @@ class GPTRM(RewardModel):
             model = GPT2Model(config)
         else:
             model = GPT2Model(GPT2Config())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
...
@@ -13,7 +13,6 @@ class LlamaCritic(Critic):
     Args:
         pretrained (str): Pretrained model name or path.
         config (LlamaConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class LlamaCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[LlamaConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none',
                  **kwargs) -> None:
@@ -33,9 +31,5 @@ class LlamaCritic(Critic):
         else:
             model = LlamaModel(LlamaConfig())

-        if checkpoint:
-            model.gradient_checkpointing_enable()
-
         value_head = nn.Linear(model.config.hidden_size, 1)

         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
@@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (LlamaConfig): Model config.
-        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): LoRA rank.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[LlamaConfig] = None,
-                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
@@ -32,8 +30,6 @@ class LlamaRM(RewardModel):
         else:
             model = LlamaModel(LlamaConfig())
-        if checkpoint:
-            model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
         value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
...
@@ -98,18 +98,18 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         return F.linear(x, T(self.weight), bias=self.bias)

-def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
     lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
     return lora_linear

-def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            setattr(module, name, lora_linear_wrapper(child, lora_rank))
+            setattr(module, name, _lora_linear_wrapper(child, lora_rank))
         else:
-            convert_to_lora_recursively(child, lora_rank)
+            _convert_to_lora_recursively(child, lora_rank)

 def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
@@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
     """
     if lora_rank <= 0:
         return module
-    convert_to_lora_recursively(module, lora_rank)
+    _convert_to_lora_recursively(module, lora_rank)
     lora.mark_only_lora_as_trainable(module, lora_train_bias)
     return module
...
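Note: the recursion helpers are private now; convert_to_lora_module stays the public entry point. A small sketch:

import torch.nn as nn
from coati.models.lora import convert_to_lora_module

mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
mlp = convert_to_lora_module(mlp, lora_rank=4)    # every nn.Linear becomes a LoraLinear
# only the LoRA factors (and optionally biases) remain trainable afterwards:
print([name for name, p in mlp.named_parameters() if p.requires_grad])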
@@ -68,31 +68,6 @@ class ValueLoss(nn.Module):
         return 0.5 * loss

-class PPOPtxActorLoss(nn.Module):
-    """
-    To Do:
-
-    PPO-ptx Actor Loss
-    """
-
-    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
-        super().__init__()
-        self.pretrain_coef = pretrain_coef
-        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
-        self.pretrain_loss_fn = pretrain_loss_fn
-
-    def forward(self,
-                log_probs: torch.Tensor,
-                old_log_probs: torch.Tensor,
-                advantages: torch.Tensor,
-                lm_logits: torch.Tensor,
-                lm_input_ids: torch.Tensor,
-                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
-        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
-        return policy_loss + self.pretrain_coef * lm_loss

 class LogSigLoss(nn.Module):
     """
     Pairwise Loss for Reward Model
...
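Note: PPOPtxActorLoss was an unused stub (its docstring still read "To Do"), hence the removal. The same ptx objective can still be composed in a trainer from the pieces that remain; a hedged sketch with illustrative names:

from coati.models import PolicyLoss
from coati.models.loss import GPTLMLoss

policy_loss_fn = PolicyLoss()
ptx_loss_fn = GPTLMLoss()
ptx_coef = 0.9    # illustrative mixing coefficient

# inside a training step:
# actor_loss = policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
# ptx_loss = ptx_loss_fn(lm_logits, lm_input_ids)
# loss = actor_loss + ptx_coef * ptx_loss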