Commit 9e768b59 authored by zhuwenwen
parents 7bc5a8e3 8aed02b9
import argparse
import os
import socket
from functools import partial
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
get_actor_from_args,
get_critic_from_args,
get_receivers_per_sender,
get_reward_model_from_args,
get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
"local_rank": "0",
"rank": "0",
"world_size": "1",
"master_port": maker_port,
"master_addr": master_addr,
}
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
def model_fn():
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = (
get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
)
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
)
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = (
get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
.half()
.cuda()
)
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
train_batch_size=args.train_batch_size,
buffer_limit=16,
eval_performance=True,
debug=args.debug,
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
return dataloader
# uncomment this block if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])
wait_tasks = []
wait_tasks.append(
experience_holder_ref.workingloop.remote(
partial(build_dataloader, dataset_size), num_steps=args.experience_steps
)
)
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
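
For reference, the trainer step count computed above works out as follows with the default CLI values (a standalone sketch, not part of the script):

# Standalone sketch of the total_steps arithmetic above, using the default CLI values.
experience_batch_size, experience_steps = 8, 4  # the single maker produces 32 experience samples in total
num_trainers, train_batch_size = 1, 8
total_steps = experience_batch_size * experience_steps // (num_trainers * train_batch_size)
assert total_steps == 4  # each trainer performs 4 fit steps
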
import argparse
import os
import socket
from functools import partial
import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
get_actor_from_args,
get_critic_from_args,
get_receivers_per_sender,
get_reward_model_from_args,
get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_makers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_makers),
"master_port": maker_port,
"master_addr": master_addr,
}
for rank in range(args.num_makers)
]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
def model_fn():
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = (
get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
)
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
)
for i, env_info_maker in enumerate(env_info_makers)
]
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = (
get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
.half()
.cuda()
)
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
f"maker{x}"
for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
train_batch_size=args.train_batch_size,
buffer_limit=16,
eval_performance=True,
debug=args.debug,
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
return dataloader
# uncomment this block if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])
wait_tasks = []
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(
experience_holder_ref.workingloop.remote(
partial(build_dataloader, dataset_size), num_steps=args.experience_steps
)
)
total_steps = (
args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
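
The multi-maker variant scales the same arithmetic by num_makers, since every maker contributes experience_batch_size samples per experience step (standalone sketch with illustrative values):

# Standalone sketch of the multi-maker step arithmetic (values are illustrative).
num_makers, num_trainers = 2, 2
experience_batch_size, experience_steps = 8, 4
train_batch_size = 8
samples_produced = num_makers * experience_batch_size * experience_steps  # 64 samples overall
total_steps = samples_produced // (num_trainers * train_batch_size)  # 4 fit steps per trainer
assert total_steps == 4
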
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0

__all__ = [
    "RmStaticDataset",
    "HhRlhfDataset",
    "SFTDataset",
    "SupervisedDataset",
    "PromptDataset",
    "is_rank_0",
]
# Copyright 2023 lm-sys@FastChat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from enum import Enum, auto
from typing import List
class SeparatorStyle(Enum):
ADD_EOS_TOKEN = auto()
@dataclasses.dataclass
class Conversation:
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
sep: str = "</s>"
skip_next: bool = False
def get_prompt(self):
if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
ret = self.system
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ": "
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
)
def dict(self):
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
}
conv = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_EOS_TOKEN,
sep="</s>",
)
default_conversation = conv
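
For reference, a quick sketch of how this template renders a prompt (the system text is elided; expected output shown as a comment):

# Sketch: rendering a two-turn prompt with the default template (illustrative only).
conv = default_conversation.copy()
conv.append_message(conv.roles[0], "What is RLHF?")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()
# prompt == system_text + "Human: What is RLHF?</s>Assistant: "
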
from collections import defaultdict
from typing import Dict

import torch
import transformers
from torch.utils.data import Dataset

from colossalai.logging import get_dist_logger

from .utils import jload


class PromptDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        max_datasets_size: int = None,
        max_length: int = 96,
    ):
        super(PromptDataset, self).__init__()
        self.keyed_prompt = defaultdict(list)
        self.logger = get_dist_logger()
        self.logger.info("Loading data...")
        list_data_dict = jload(data_path)
        self.logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
        tokens = tokenizer(
            instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
        )
        for k, tensor in tokens.items():
            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

    def __len__(self):
        return len(self.keyed_prompt["input_ids"])

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return {k: v[i] for k, v in self.keyed_prompt.items()}
@@ -6,7 +6,7 @@ from tqdm import tqdm

from .utils import is_rank_0


# Dahoas/rm-static
class RmStaticDataset(Dataset):
    """
    Dataset for reward model
@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        length = self.chosen["input_ids"].shape[0]
        return length

    def __getitem__(self, idx):
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
# Anthropic/hh-rlhf
@@ -74,39 +61,28 @@ class HhRlhfDataset(Dataset):

    def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
        super().__init__()
        self.end_token = tokenizer.eos_token if special_token is None else special_token

        chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        chosen_token = tokenizer(
            chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

        reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
        reject_token = tokenizer(
            reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

    def __len__(self):
        length = self.chosen["input_ids"].shape[0]
        return length

    def __getitem__(self, idx):
        return (
            self.chosen["input_ids"][idx],
            self.chosen["attention_mask"][idx],
            self.reject["input_ids"][idx],
            self.reject["attention_mask"][idx],
        )
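
Both datasets yield aligned (chosen, rejected) tensors, ready for a pairwise ranking loss. A minimal usage sketch, assuming the Hugging Face `datasets` and `transformers` packages and the Dahoas/rm-static dataset named in the comment above:

# Sketch: batching preference pairs for reward-model training (illustrative setup).
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
raw_train = load_dataset("Dahoas/rm-static")["train"]
train_dataset = RmStaticDataset(raw_train, tokenizer, max_length=512)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
chosen_ids, chosen_mask, reject_ids, reject_mask = next(iter(loader))
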
@@ -13,15 +13,13 @@
# limitations under the License.

import copy
from typing import Dict, Optional, Sequence, Tuple

import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from colossalai.logging import get_dist_logger

@@ -31,16 +29,89 @@ logger = get_dist_logger()
IGNORE_INDEX = -100
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
def _preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
sequences_token = tokenizer(
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
sources_token = tokenizer(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()
pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
            labels[i][:source_len] = IGNORE_INDEX
            if pad_len > 0:  # guard: with pad_len == 0, labels[i][-0:] would mask the entire row
                labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][: pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
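
A small sketch of what `_preprocess` produces for a toy sample, assuming a GPT-2 tokenizer with right padding (illustrative, not a test from this commit):

# Sketch: only completion tokens are supervised; prompt and pad positions get IGNORE_INDEX.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
ids, labels, mask = _preprocess(["Q: 2+2=\nA: "], ["4" + tok.eos_token], tok, max_length=16)
# labels[0] holds -100 over the prompt and padding, and real token ids over "4<eos>".
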
def _preprocess_chatglm(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preprocess the data by tokenizing.
None for attention mask, ChatGLM will calculate attention mask according to input ids
"""
labels = []
input_ids = []
for source, target in zip(sources, targets):
source_id = tokenizer.encode(text=source, add_special_tokens=False)
target_id = tokenizer.encode(text=target, add_special_tokens=False)
input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
# truncate
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
truncate_length = max(0, len(input_id) - max_length)
input_id = input_id[truncate_length:]
if truncate_length == len(source_id) + 1:
input_id = sp_token_list + input_id[1:]
elif truncate_length > len(source_id) + 1:
input_id = sp_token_list + input_id[2:]
context_length = input_id.index(tokenizer.bos_token_id)
mask_position = context_length - 1
label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
pad_len = max_length - len(input_id)
input_id = input_id + [tokenizer.pad_token_id] * pad_len
input_ids.append(input_id)
labels.append(label + [IGNORE_INDEX] * pad_len)
return torch.tensor(input_ids), torch.tensor(labels), None
class SFTDataset(Dataset):
    """
    Dataset for sft model
@@ -51,73 +122,45 @@ class SFTDataset(Dataset):
        max_length: max length of input
    """

    def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
        super().__init__()
        self.input_ids = []

        sources = [data["prompt"] for data in dataset]
        targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                sources, targets, tokenizer, max_length
            )
        else:
            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

        logger.info("Loaded dataset.")

    def __len__(self):
        length = self.input_ids.shape[0]
        return length

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_datasets_size: Optional[int] = None,
        max_length: int = 512,
    ):
        super().__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")
@@ -129,38 +172,27 @@ class SupervisedDataset(Dataset):

        logger.info("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]

        logger.info("Tokenizing inputs... This may take some time...")
        if isinstance(tokenizer, ChatGLMTokenizer):
            self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
                sources, targets, tokenizer, max_length
            )
        else:
            self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

        logger.info("Loaded dataset.")

    def __len__(self):
        length = self.input_ids.shape[0]
        return length

    def __getitem__(self, idx):
        if self.attention_mask is not None:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
        else:
            return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
@@ -4,12 +4,12 @@ from typing import Any

from coati.experience_maker.base import Experience


class ExperienceBuffer(ABC):
    """Experience buffer base class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
...
import random
import warnings
from typing import List

import torch
from coati.experience_maker.base import Experience

from .base import ExperienceBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch


class NaiveExperienceBuffer(ExperienceBuffer):
    """Naive experience buffer class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
        super().__init__(sample_batch_size, limit)
        self.cpu_offload = cpu_offload
        self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
        # TODO(ver217): add prefetch
        self.items: List[BufferItem] = []

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        if self.cpu_offload:
            experience.to_device(torch.device("cpu"))
        items = split_experience_batch(experience)
        self.items.extend(items)
        if self.limit > 0:
            samples_to_remove = len(self.items) - self.limit
            if samples_to_remove > 0:
                warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
                self.items = self.items[samples_to_remove:]

    def clear(self) -> None:
...
@@ -21,6 +21,7 @@ class BufferItem:
    "A" is the number of actions.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
@@ -33,7 +34,7 @@ class BufferItem:

def split_experience_batch(experience: Experience) -> List[BufferItem]:
    batch_size = experience.sequences.size(0)
    batch_kwargs = [{} for _ in range(batch_size)]
    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
    for key in keys:
        value = getattr(experience, key)
        if isinstance(value, torch.Tensor):
@@ -48,25 +49,25 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
    return items


def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
    assert side in ("left", "right")
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)


def make_experience_batch(items: List[BufferItem]) -> Experience:
    kwargs = {}
    to_pad_keys = set(("action_log_probs", "action_mask"))
    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if key in to_pad_keys:
            batch_data = _zero_pad_sequences(vals)
        else:
            batch_data = torch.stack(vals, dim=0)
        kwargs[key] = batch_data
...
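
Because `action_log_probs` and `action_mask` may differ in length across buffer items, `_zero_pad_sequences` left-pads them to a common length before stacking. A toy illustration:

# Toy illustration of left padding before stacking (values are illustrative).
import torch
import torch.nn.functional as F

a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
max_len = max(a.size(0), b.size(0))
stacked = torch.stack([F.pad(t, (max_len - t.size(0), 0)) for t in (a, b)], dim=0)
# stacked: tensor([[1, 2, 3],
#                  [0, 4, 5]])
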
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker

__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
@@ -3,14 +3,13 @@ from dataclasses import dataclass
from typing import Optional

import torch
from coati.models.base import Actor, Critic, RewardModel


@dataclass
class Experience:
    """Experience is a batch of data.
    These data should have the sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor:
@@ -24,6 +23,7 @@ class Experience:
    "A" is the number of actions.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
@@ -58,20 +58,13 @@ class Experience:


class ExperienceMaker(ABC):
    def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.initial_model = initial_model

    @abstractmethod
    def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
        pass
import torch
import torch.nn.functional as F
from coati.models.base import Actor, Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedTokenizer

from .base import Experience, ExperienceMaker
@@ -9,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
    Naive experience maker.
    """

    def __init__(
        self,
        actor: Actor,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: Actor,
        tokenizer: PreTrainedTokenizer,
        kl_coef: float = 0.1,
    ) -> None:
        super().__init__(actor, critic, reward_model, initial_model)
        self.tokenizer = tokenizer
        self.kl_coef = kl_coef

    @torch.no_grad()
    def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
        self.actor.eval()
@@ -16,14 +33,33 @@ class NaiveExperienceMaker(ExperienceMaker):
        self.initial_model.eval()
        self.reward_model.eval()

        # generate sequences
        sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)

        # calculate auxiliary tensors
        attention_mask = None
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)

        input_len = input_ids.size(1)
        eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is None:
            action_mask = torch.ones_like(sequences, dtype=torch.bool)
        else:
            # left padding may be applied, only mask action
            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
        action_mask[:, :input_len] = False
        action_mask = action_mask[:, 1:]
        action_mask = action_mask[:, -(sequences.size(1) - input_len):]
        num_actions = action_mask.size(1)

        actor_output = self.actor(sequences, attention_mask)["logits"]
        action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
        base_model_output = self.initial_model(sequences, attention_mask)["logits"]
        base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
        value = self.critic(sequences, attention_mask)
        r = self.reward_model(sequences, attention_mask)
        reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
...
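
The cumsum trick above masks everything after the first EOS in the generated region: positions at or past the first EOS accumulate a positive count. A small sketch of the core step (the F.pad shift then re-includes the EOS itself):

# Sketch of the EOS cutoff (token ids are illustrative; assume eos_token_id == 2).
import torch

generated = torch.tensor([[5, 7, 2, 0, 0]])  # tokens after the prompt
keep = (generated == 2).cumsum(dim=-1) == 0  # True strictly before the first EOS
# keep: tensor([[ True,  True, False, False, False]])
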
from .wrapper import convert_to_xformer_model, recover_from_xformer_model

__all__ = [
    "convert_to_xformer_model",
    "recover_from_xformer_model",
]
@@ -21,11 +21,12 @@ class XOPTAttention(OPTAttention):
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
        if not self.training:
            return super().forward(
                hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
            )
        """Input shape: Batch x Time x Channel"""

        assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
        assert not output_attentions, "Xformers attention does not support output_attentions"

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
@@ -69,15 +70,17 @@ class XOPTAttention(OPTAttention):
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = xops.memory_efficient_attention(
            query_states,
            key_states,
            value_states,
            attn_bias=xops.LowerTriangularMask(),
            p=self.dropout if self.training else 0.0,
            scale=self.scaling,
        )

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
...
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss

__all__ = [
    "Actor",
    "Critic",
    "RewardModel",
    "PolicyLoss",
    "ValueLoss",
    "LogSigLoss",
    "LogExpLoss",
    "LoRAModule",
    "convert_to_lora_module",
]
from typing import Union

import torch.nn as nn

from .actor import Actor
@@ -5,10 +7,10 @@ from .critic import Critic
from .reward_model import RewardModel


def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
    """Get the base model of our wrapper classes.
    For Actor, Critic and RewardModel, return ``model.model``,
    it's usually a ``transformers.PreTrainedModel``.

    Args:
        model (nn.Module): model to get base model from
@@ -16,9 +18,10 @@ def get_base_model(model: nn.Module) -> nn.Module:
    Returns:
        nn.Module: the base model
    """
    assert isinstance(
        model, (Actor, Critic, RewardModel)
    ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
    return model.model


__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]
from typing import Optional

import torch
import torch.nn as nn

from ..lora import LoRAModule


class Actor(LoRAModule):
@@ -19,47 +16,18 @@ class Actor(LoRAModule):
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.convert_to_lora()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs,
    ) -> torch.Tensor:
        """Returns model output."""
        output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
        return output
import torch
import torch.nn as nn

from ..lora import LoRAModule


class Critic(LoRAModule):
@@ -19,36 +16,19 @@ class Critic(LoRAModule):
    """

    def __init__(
        self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
    ) -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        self.value_head = value_head
        self.convert_to_lora()

    def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs["last_hidden_state"]
        sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
            0
        ]
        sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
        values = self.value_head(sequence_hidden_states).squeeze(1)  # ensure shape is (B, )
        return values
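
The refactored critic scores a sequence by the hidden state of its last attended token; the `torch.max` over mask-weighted positions recovers that index. A toy sketch:

# Toy sketch: finding the last non-pad position per row from an attention mask.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
positions = torch.arange(attention_mask.size(1))
last_idx = torch.max(attention_mask * positions, dim=1)[0]
# last_idx: tensor([2, 4]) -> index of the final attended token in each row
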