Unverified Commit da4f7b85 authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[chat] fix bugs and add unit tests (#4213)

* style: rename replay buffer

Experience replay is typically for off policy algorithms.
Use this name in PPO maybe misleading.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
parent 16bf4c02
......@@ -43,7 +43,9 @@ jobs:
run: |
cd applications/Chat
rm -rf ~/.cache/colossalai
./examples/test_ci.sh
./tests/test_inference.sh
./tests/test_benchmarks.sh
./tests/test_train.sh
env:
NCCL_SHM_DISABLE: 1
MAX_JOBS: 8
......
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0
__all__ = [
'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
'DataCollatorForSupervisedDataset', 'PromptDataset'
'RmStaticDataset', 'HhRlhfDataset',
'SFTDataset', 'SupervisedDataset',
'PromptDataset', 'is_rank_0',
]
import copy
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
from typing import Dict
import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
logger = get_dist_logger()
from .utils import jload
class PromptDataset(Dataset):
......@@ -27,12 +20,13 @@ class PromptDataset(Dataset):
max_length: int = 96):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
logger.info("Loading data...")
self.logger = get_dist_logger()
self.logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
self.logger.info(f"Loaded {len(list_data_dict)} examples.")
if max_datasets_size is not None:
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
......
......@@ -20,44 +20,44 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt']
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
chosen = [
data["prompt"] + data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
reject = [
data["prompt"] + data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
def __len__(self):
length = len(self.chosen)
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
# Anthropic/hh-rlhf
......@@ -74,39 +74,41 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
chosen = data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
reject = data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
chosen = [
data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
reject = [
data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
def __len__(self):
length = len(self.chosen)
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
......@@ -13,44 +13,64 @@
# limitations under the License.
import copy
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence, Tuple
from typing import Dict, Sequence, Tuple
import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
from .conversation import default_conversation
from .utils import is_rank_0, jload
# The following is a template prompt for a 4-round conversation.
"""
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>Human: xxx</s>Assistant: xxx</s>
"""
# Please note that we only calculate loss on assistant's answer tokens.
logger = get_dist_logger()
IGNORE_INDEX = -100
DEFAULT_EOS_TOKEN = "</s>"
PROMPT_DICT = {
"prompt_input":
("Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
"prompt_no_input": ("Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"),
}
def _preprocess(sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
sequences_token = tokenizer(sequences,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
sources_token = tokenizer(sources,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()
pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
class SFTDataset(Dataset):
"""
Dataset for sft model
......@@ -61,115 +81,31 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
def __init__(self,
dataset: Dict,
tokenizer: PreTrainedTokenizer,
max_length: int = 512
) -> None:
super().__init__()
self.input_ids = []
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
prompt_token = tokenizer(prompt,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
sources = [data["prompt"] for data in dataset]
targets = [
data["completion"] + tokenizer.eos_token
for data in tqdm(dataset, disable=not is_rank_0())
]
self.input_ids.append(prompt_token['input_ids'][0])
self.labels = copy.deepcopy(self.input_ids)
self.input_ids, self.labels, self.attention_mask = \
_preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = len(self.input_ids)
length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
max_length: int) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
def preprocess_conversation(sources: List[List[Dict]], tokenizer: transformers.PreTrainedTokenizer,
max_length: int) -> Dict:
"""Preprocess the conversation data by tokenizing."""
conversations = []
intermediates = []
for source in sources:
header = f"{default_conversation.system}"
conversation, intermediate = _add_speaker_and_signal(header, source)
conversations.append(conversation)
intermediates.append(intermediate)
conversations_tokenized = _tokenize_fn(conversations, tokenizer, max_length)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
assert len(targets) == len(intermediates)
for target, inters in zip(targets, intermediates):
mask = torch.zeros_like(target, dtype=torch.bool)
for inter in inters:
tokenized = _tokenize_fn(inter, tokenizer, max_length)
start_idx = tokenized["input_ids"][0].size(0) - 1
end_idx = tokenized["input_ids"][1].size(0)
mask[start_idx:end_idx] = True
target[~mask] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def _add_speaker_and_signal(header: str,
source: List[Dict],
get_conversation: bool = True) -> Tuple[str, List[List[str]]]:
END_SIGNAL = DEFAULT_EOS_TOKEN
conversation = header
intermediate = []
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = default_conversation.roles[1]
else:
from_str = 'unknown'
value = from_str + ": " + sentence["value"] + END_SIGNAL
if sentence["from"].lower() == "gpt":
start = conversation + from_str + ": "
end = conversation + value
intermediate.append([start, end])
if get_conversation:
conversation += value
return conversation, intermediate
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
class SupervisedDataset(Dataset):
......@@ -177,10 +113,10 @@ class SupervisedDataset(Dataset):
def __init__(self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
tokenizer: PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 512):
super(SupervisedDataset, self).__init__()
super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
......@@ -190,52 +126,25 @@ class SupervisedDataset(Dataset):
list_data_dict = list_data_dict[:max_datasets_size]
logger.info("Formatting inputs...")
if "conversations" not in list_data_dict[0]:
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example)
if example.get("input", "") != "" else prompt_no_input.format_map(example) for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
if is_rank_0():
logger.info("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer, max_length)
else:
if is_rank_0():
logger.info("Tokenizing inputs... This may take some time...")
sources = [conv["conversations"] for conv in list_data_dict]
data_dict = preprocess_conversation(sources, tokenizer, max_length)
if is_rank_0():
logger.info("Tokenizing finish.")
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [
example['output'] + tokenizer.eos_token
for example in list_data_dict
]
logger.info("Tokenizing inputs... This may take some time...")
self.input_ids, self.labels, self.attention_mask = \
_preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
length = self.input_ids.shape[0]
return length
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
......@@ -4,8 +4,8 @@ from typing import Any
from coati.experience_maker.base import Experience
class ReplayBuffer(ABC):
"""Replay buffer base class. It stores experience.
class ExperienceBuffer(ABC):
"""Experience buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
......
......@@ -4,12 +4,12 @@ from typing import List
import torch
from coati.experience_maker.base import Experience
from .base import ReplayBuffer
from .base import ExperienceBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveReplayBuffer(ReplayBuffer):
"""Naive replay buffer class. It stores experience.
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
......
......@@ -33,7 +33,8 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
keys = ('sequences', 'action_log_probs', 'values',
'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
......@@ -48,7 +49,7 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
return items
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
......@@ -62,11 +63,12 @@ def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> tor
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
keys = ('sequences', 'action_log_probs', 'values',
'reward', 'advantages', 'attention_mask', 'action_mask')
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:
batch_data = zero_pad_sequences(vals)
batch_data = _zero_pad_sequences(vals)
else:
batch_data = torch.stack(vals, dim=0)
kwargs[key] = batch_data
......
import torch
from coati.models.generation import generate_with_actor
from coati.models.utils import calc_action_log_probs, compute_reward, normalize
import torch.nn.functional as F
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from .base import Experience, ExperienceMaker
......@@ -17,10 +18,27 @@ class NaiveExperienceMaker(ExperienceMaker):
self.initial_model.eval()
self.reward_model.eval()
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
input_ids,
return_action_mask=True,
**generate_kwargs)
# generate sequences
sequences = generate(self.actor, input_ids, **generate_kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = generate_kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id)\
.to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
......
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
]
......@@ -14,7 +14,6 @@ class BLOOMCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
......@@ -22,7 +21,6 @@ class BLOOMCritic(Critic):
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
......@@ -32,7 +30,6 @@ class BLOOMCritic(Critic):
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
......@@ -13,7 +13,6 @@ class BLOOMRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
......@@ -21,7 +20,6 @@ class BLOOMRM(RewardModel):
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
......@@ -30,8 +28,7 @@ class BLOOMRM(RewardModel):
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .base import Actor
try:
from transformers.generation_logits_process import (
......@@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
def _prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
......@@ -37,22 +37,22 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
def sample(model: nn.Module,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def _sample(model: Actor,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
......@@ -89,7 +89,8 @@ def sample(model: nn.Module,
return input_ids
def generate(model: nn.Module,
@torch.no_grad()
def generate(model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
......@@ -128,51 +129,19 @@ def generate(model: nn.Module,
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
return sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
return _sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
elif is_beam_gen_mode:
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")
@torch.no_grad()
def generate_with_actor(
actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
sequences = generate(actor_model, input_ids, **kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
......@@ -14,7 +14,6 @@ class GPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the LO-RA decomposition.
lora_train_bias (str): LoRA bias training mode.
"""
......@@ -22,7 +21,6 @@ class GPTCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
......@@ -32,7 +30,6 @@ class GPTCritic(Critic):
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
......@@ -14,7 +14,6 @@ class GPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
......@@ -22,7 +21,6 @@ class GPTRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
......@@ -31,8 +29,6 @@ class GPTRM(RewardModel):
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
......
......@@ -13,7 +13,6 @@ class LlamaCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
......@@ -21,7 +20,6 @@ class LlamaCritic(Critic):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
......@@ -33,9 +31,5 @@ class LlamaCritic(Critic):
else:
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
......@@ -13,7 +13,6 @@ class LlamaRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
......@@ -21,7 +20,6 @@ class LlamaRM(RewardModel):
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
......@@ -32,8 +30,6 @@ class LlamaRM(RewardModel):
else:
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
......
......@@ -98,18 +98,18 @@ class LoraLinear(lora.LoRALayer, nn.Module):
return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, lora_linear_wrapper(child, lora_rank))
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
else:
convert_to_lora_recursively(child, lora_rank)
_convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
......@@ -124,7 +124,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
"""
if lora_rank <= 0:
return module
convert_to_lora_recursively(module, lora_rank)
_convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
......
......@@ -68,31 +68,6 @@ class ValueLoss(nn.Module):
return 0.5 * loss
class PPOPtxActorLoss(nn.Module):
"""
To Do:
PPO-ptx Actor Loss
"""
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
super().__init__()
self.pretrain_coef = pretrain_coef
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
self.pretrain_loss_fn = pretrain_loss_fn
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
lm_logits: torch.Tensor,
lm_input_ids: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
return policy_loss + self.pretrain_coef * lm_loss
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment