"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "a1ce02d740e5013a14225620d60133104a07e8fa"
Unverified Commit da4f7b85 authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[chat] fix bugs and add unit tests (#4213)

* style: rename replay buffer

Experience replay is typically for off policy algorithms.
Using this name in PPO may be misleading.

* fix: fix wrong zero2 default arg

* test: update experience tests

* style: rename zero_pad fn

* fix: defer init in CycledDataLoader

* test: add benchmark test

* style: rename internal fn of generation

* style: rename internal fn of lora

* fix: remove unused loss fn

* fix: remove unused utils fn

* refactor: remove generate_with_actor fn

* fix: fix type annotation

* test: add models tests

* fix: skip llama due to long execution time

* style: modify dataset

* style: apply formatter

* perf: update reward dataset

* fix: fix wrong IGNORE_INDEX in sft dataset

* fix: remove DataCollatorForSupervisedDataset

* test: add dataset tests

* style: apply formatter

* style: rename test_ci to test_train

* feat: add llama in inference

* test: add inference tests

* test: change test scripts directory

* fix: update ci

* fix: fix typo

* fix: skip llama due to oom

* fix: fix file mod

* style: apply formatter

* refactor: remove duplicated llama_gptq

* style: apply formatter

* to: update rm test

* feat: add tokenizer arg

* feat: add download model script

* test: update train tests

* fix: modify gemini load and save pretrained

* test: update checkpoint io test

* to: modify nproc_per_node

* fix: do not remove existing dir

* fix: modify save path

* test: add random choice

* fix: fix sft path

* fix: enlarge nproc_per_node to avoid oom

* fix: add num_retry

* fix: make lora config of rm and critic consistent

* fix: add warning about lora weights

* fix: skip some gpt2 tests

* fix: remove grad ckpt in rm and critic due to errors

* refactor: directly use Actor in train_sft

* test: add more arguments

* fix: disable grad ckpt when using lora

* fix: fix save_pretrained and related tests

* test: enable zero2 tests

* revert: remove useless fn

* style: polish code

* test: modify test args
parent 16bf4c02
#!/usr/bin/env bash
# Pick the $1 GPUs with the least used memory and expose only those via
# CUDA_VISIBLE_DEVICES. Defaults to "9999" (i.e. all GPUs) when no count
# is given. Requires nvidia-smi; the sorted usage table is echoed to the
# terminal for debugging via tee /dev/tty.
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    # Split declaration from assignment so a failing nvidia-smi is not
    # masked by `local` always succeeding; quote "$n" for head.
    local FIRST_N_GPU_IDS
    FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n "$n")
    # Join the one-id-per-line list with commas (command substitution has
    # already stripped the trailing newline, so no trailing comma appears).
    export CUDA_VISIBLE_DEVICES=$(printf '%s' "$FIRST_N_GPU_IDS" | tr '\n' ',')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

set -xue

# NOTE: `set -u` (enabled above) aborts with "unbound variable" when an
# unset variable is referenced, so the friendly error below would never be
# printed. Use ${VAR:-} to safely test for unset-or-empty.
if [ -z "${SFT_DATASET:-}" ]; then
    echo "Please set \$SFT_DATASET to the path to sft dataset."
    exit 1
fi

if [ -z "${PROMPT_PATH:-}" ]; then
    echo "Please set \$PROMPT_PATH to the path to prompts csv."
    exit 1
fi

if [ -z "${PRETRAIN_DATASET:-}" ]; then
    echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
    exit 1
fi
# All paths below (training scripts, checkpoints, requirements) are
# resolved relative to the directory containing this script.
BASE=$(realpath $(dirname $0))

export OMP_NUM_THREADS=8

# install requirements
pip install -r ${BASE}/requirements.txt

# run wandb in offline mode — presumably so CI needs no wandb account/network; confirm
wandb init -m offline

# FIXME: This is a hack to skip tests that are not working
#   - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
#   - llama-*: These tests can be passed locally, skipped for long execution time
# Entries are "<model>-<strategy>" strings matched by the loop below.
SKIPPED_TESTS=(
    "gpt2-ddp"
    "llama-ddp"
    "llama-colossalai_gemini"
    "llama-colossalai_zero2"
)
# These tests are quick and do not have any dependencies
# Run one tiny PPO smoke test per model x strategy pair, skipping the
# combinations listed in SKIPPED_TESTS. All variable expansions are quoted
# so paths containing whitespace do not break the command line.
for model in 'gpt2' 'bloom' 'opt' 'llama'; do
    for strategy in 'ddp' 'colossalai_gemini' 'colossalai_zero2'; do
        # The padded " ${arr[*]} " =~ " item " idiom tests exact membership.
        if [[ " ${SKIPPED_TESTS[*]} " =~ " ${model}-${strategy} " ]]; then
            echo "[Test]: Skipped $model-$strategy"
            continue
        fi
        torchrun --standalone --nproc_per_node=2 "${BASE}/train_prompts.py" \
            --prompt_dataset "$PROMPT_PATH" --pretrain_dataset "$PRETRAIN_DATASET" \
            --strategy "$strategy" --model "$model" \
            --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
            --train_batch_size 2 --lora_rank 4
    done
done
# train sft
# Each SFT run trains one model/strategy combination for a single epoch on a
# 512-sample slice of $SFT_DATASET; the checkpoint dir is removed after each run.
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \
--model 'bloom' --strategy colossalai_zero2 --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy colossalai_zero2 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
--model 'opt' --strategy colossalai_zero2 --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy ddp --lora_rank 4 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
rm -rf ${BASE}/output
# train rm
# The opt and gpt2 reward-model checkpoints (rm_ckpt_opt.pt / rm_ckpt_gpt.pt)
# are kept on disk because the RL stage below loads them via --rm_path.
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'facebook/opt-350m' --model 'opt' \
--strategy colossalai_zero2 --loss_fn 'log_sig' \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \
--strategy colossalai_zero2 --loss_fn 'log_exp' \
--dataset 'Dahoas/rm-static' \
--test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_gpt.pt
# The two LoRA (lora_rank 4) runs below only smoke-test RM training and
# their checkpoints are deleted immediately after.
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \
--strategy ddp --loss_fn 'log_exp' \
--dataset 'Dahoas/rm-static' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
--strategy colossalai_zero2 --loss_fn 'log_sig' \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
rm -rf ${BASE}/rm_ckpt.pt
# train rl
# PPO smoke tests reusing the RM checkpoints produced above; every run
# writes a throwaway actor checkpoint that is cleaned up at the end.
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_zero2 --num_episodes 1 \
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
--pretrain 'facebook/opt-350m' --model opt \
--rm_pretrain 'facebook/opt-350m' \
--rm_path ${BASE}/rm_ckpt_opt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_zero2 --num_episodes 1 \
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
--pretrain 'gpt2' --model gpt2 \
--rm_pretrain 'gpt2' \
--rm_path ${BASE}/rm_ckpt_gpt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_gemini --num_episodes 1 \
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
--pretrain 'gpt2' --model gpt2 \
--rm_pretrain 'gpt2' \
--rm_path ${BASE}/rm_ckpt_gpt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_gpt.pt
rm -rf ${BASE}/actor_checkpoint_prompts.pt
# 3080 doesn't support P2P, skip this test
# cd ${BASE}/ray && bash test_ci.sh && cd ${BASE}
import argparse import argparse
import warnings
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset from coati.dataset import PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
...@@ -29,6 +30,7 @@ def main(args): ...@@ -29,6 +30,7 @@ def main(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None: if args.rm_path is not None:
warnings.warn('LoRA weights should be merged with the model weights')
state_dict = torch.load(args.rm_path, map_location='cpu') state_dict = torch.load(args.rm_path, map_location='cpu')
with strategy.model_init_context(): with strategy.model_init_context():
...@@ -50,18 +52,18 @@ def main(args): ...@@ -50,18 +52,18 @@ def main(args):
rm_model_name = args.rm_model rm_model_name = args.rm_model
if rm_model_name == 'gpt2': if rm_model_name == 'gpt2':
reward_model = GPTRM(pretrained=args.rm_pretrain) reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'bloom': elif rm_model_name == 'bloom':
reward_model = BLOOMRM(pretrained=args.rm_pretrain) reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'opt': elif rm_model_name == 'opt':
reward_model = OPTRM(pretrained=args.rm_pretrain) reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
elif rm_model_name == 'llama': elif rm_model_name == 'llama':
reward_model = LlamaRM(pretrained=args.rm_pretrain) reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else: else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None: if args.rm_path is not None:
reward_model.load_state_dict(state_dict) reward_model.load_state_dict(state_dict, strict=False)
initial_model.to(torch.float16).to(torch.cuda.current_device()) initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device())
...@@ -89,7 +91,7 @@ def main(args): ...@@ -89,7 +91,7 @@ def main(args):
raise ValueError(f'Unsupported reward model "{rm_model_name}"') raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None: if args.rm_path is not None:
critic.load_state_dict(state_dict) critic.load_state_dict(state_dict, strict=False)
del state_dict del state_dict
if args.strategy != 'colossalai_gemini': if args.strategy != 'colossalai_gemini':
...@@ -106,23 +108,25 @@ def main(args): ...@@ -106,23 +108,25 @@ def main(args):
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained(
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == 'llama':
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>' tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384) prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
...@@ -144,8 +148,7 @@ def main(args): ...@@ -144,8 +148,7 @@ def main(args):
pretrain_dataloader = DataLoader(pretrain_dataset, pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None), shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler, sampler=pretrain_sampler,
batch_size=args.ptx_batch_size, batch_size=args.ptx_batch_size)
collate_fn=data_collator)
# NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \
...@@ -197,6 +200,7 @@ if __name__ == '__main__': ...@@ -197,6 +200,7 @@ if __name__ == '__main__':
default='colossalai_zero2', default='colossalai_zero2',
help='strategy to use') help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--rm_path', type=str, default=None) parser.add_argument('--rm_path', type=str, default=None)
......
...@@ -36,34 +36,39 @@ def train(args): ...@@ -36,34 +36,39 @@ def train(args):
# configure model # configure model
with strategy.model_init_context(): with strategy.model_init_context():
if args.model == 'bloom': if args.model == 'bloom':
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'opt': elif args.model == 'opt':
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'gpt2': elif args.model == 'gpt2':
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'llama': elif args.model == 'llama':
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
model.to(torch.float16).to(torch.cuda.current_device())
if args.model_path is not None: if args.model_path is not None:
state_dict = torch.load(args.model_path) state_dict = torch.load(args.model_path)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model = model.to(torch.float16)
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer = AutoTokenizer.from_pretrained(
"facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama': elif args.model == 'llama':
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -89,8 +94,8 @@ def train(args): ...@@ -89,8 +94,8 @@ def train(args):
data = load_dataset(args.dataset) data = load_dataset(args.dataset)
if args.test: if args.test:
train_data = data['train'].select(range(100)) train_data = data['train'].select(range(20))
eval_data = data['test'].select(range(10)) eval_data = data['test'].select(range(5))
else: else:
train_data = data['train'] train_data = data['train']
eval_data = data['test'] eval_data = data['test']
...@@ -177,6 +182,7 @@ if __name__ == '__main__': ...@@ -177,6 +182,7 @@ if __name__ == '__main__':
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='colossalai_zero2') default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None) parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument('--need_optim_ckpt', type=bool, default=False)
...@@ -184,7 +190,7 @@ if __name__ == '__main__': ...@@ -184,7 +190,7 @@ if __name__ == '__main__':
type=str, type=str,
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
default='Dahoas/rm-static') default='Dahoas/rm-static')
parser.add_argument('--subset', type=str, default=None) parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None)
parser.add_argument('--save_path', type=str, default='rm_ckpt') parser.add_argument('--save_path', type=str, default='rm_ckpt')
parser.add_argument('--max_epochs', type=int, default=1) parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--batch_size', type=int, default=1)
......
set_n_least_used_CUDA_VISIBLE_DEVICES() { set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"} local n=${1:-"9999"}
echo "GPU Memory Usage:" echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
| tail -n +2 \ tail -n +2 |
| nl -v 0 \ nl -v 0 |
| tee /dev/tty \ tee /dev/tty |
| sort -g -k 2 \ sort -g -k 2 |
| awk '{print $1}' \ awk '{print $1}' |
| head -n $n) head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:" echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
...@@ -16,9 +16,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { ...@@ -16,9 +16,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2 set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py \ torchrun --standalone --nproc_per_node=2 train_reward_model.py \
--pretrain <your pretrain path> \ --model 'bloom' \
--model 'bloom' \ --strategy colossalai_zero2 \
--strategy colossalai_zero2 \ --loss_fn 'log_sig' \
--loss_fn 'log_sig'\ --dataset 'Anthropic/hh-rlhf'
--save_path <your model saving path>\
--dataset 'Anthropic/hh-rlhf'\
import argparse import argparse
import math import math
import os import warnings
import loralib as lora
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset from coati.dataset import SFTDataset, SupervisedDataset
from coati.models import convert_to_lora_module from coati.models.bloom import BLOOMActor
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from coati.trainer import SFTTrainer from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset from datasets import load_dataset
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.trainer import get_scheduler from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
...@@ -31,8 +29,6 @@ def train(args): ...@@ -31,8 +29,6 @@ def train(args):
if args.strategy == 'ddp': if args.strategy == 'ddp':
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == 'colossalai_gemini':
raise NotImplementedError(
'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
strategy = GeminiStrategy(placement_policy='cuda') strategy = GeminiStrategy(placement_policy='cuda')
elif args.strategy == 'colossalai_zero2': elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
...@@ -42,40 +38,49 @@ def train(args): ...@@ -42,40 +38,49 @@ def train(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model # configure model
if args.lora_rank > 0:
warnings.warn("Gradient checkpoint is disabled when using LoRA")
args.grad_checkpoint = False
with strategy.model_init_context(): with strategy.model_init_context():
if args.model == 'bloom': if args.model == 'bloom':
model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain), model = BLOOMActor(pretrained=args.pretrain,
args.lora_rank).half().cuda() lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'opt': elif args.model == 'opt':
model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda() model = OPTActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'gpt2': elif args.model == 'gpt2':
model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda() model = GPTActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
elif args.model == 'llama': elif args.model == 'llama':
model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain), model = LlamaActor(pretrained=args.pretrain,
args.lora_rank).half().cuda() lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
if args.grad_checkpoint:
model.gradient_checkpointing_enable() model.to(torch.float16).to(torch.cuda.current_device())
# configure tokenizer # configure tokenizer
if args.model == 'gpt2': if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained(
'gpt2' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom': elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') tokenizer = BloomTokenizerFast.from_pretrained(
'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt': elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
args.pretrain, "facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
padding_side="right",
use_fast=False,
)
tokenizer.eos_token = '</s>'
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token
else: else:
raise ValueError(f'Unsupported model "{args.model}"') raise ValueError(f'Unsupported model "{args.model}"')
...@@ -111,7 +116,6 @@ def train(args): ...@@ -111,7 +116,6 @@ def train(args):
max_datasets_size=args.max_datasets_size, max_datasets_size=args.max_datasets_size,
max_length=args.max_len) max_length=args.max_len)
eval_dataset = None eval_dataset = None
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, train_sampler = DistributedSampler(train_dataset,
...@@ -135,14 +139,12 @@ def train(args): ...@@ -135,14 +139,12 @@ def train(args):
shuffle=(train_sampler is None), shuffle=(train_sampler is None),
sampler=train_sampler, sampler=train_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True) pin_memory=True)
if eval_dataset is not None: if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset, eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None), shuffle=(eval_sampler is None),
sampler=eval_sampler, sampler=eval_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=data_collator,
pin_memory=True) pin_memory=True)
else: else:
eval_dataloader = None eval_dataloader = None
...@@ -184,6 +186,7 @@ if __name__ == '__main__': ...@@ -184,6 +186,7 @@ if __name__ == '__main__':
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2') default='colossalai_zero2')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None) parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--max_datasets_size', type=int, default=None) parser.add_argument('--max_datasets_size', type=int, default=None)
......
# Select the $1 GPUs currently using the least memory and expose only those
# through CUDA_VISIBLE_DEVICES (defaults to "9999", i.e. every GPU). The
# intermediate usage table is mirrored to the terminal via tee /dev/tty.
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local gpu_count=${1:-"9999"}
    echo "GPU Memory Usage:"
    local least_used_ids
    least_used_ids=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n "$gpu_count")
    # Unquoted expansion collapses the newline-separated ids to one line;
    # sed then turns the separating spaces into commas.
    export CUDA_VISIBLE_DEVICES=$(echo $least_used_ids | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 4
torchrun --standalone --nproc_per_node=4 train_sft.py \ torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \ --pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \ --model 'llama' \
--strategy colossalai_zero2 \ --strategy colossalai_zero2 \
--log_interval 10 \ --log_interval 10 \
--save_path /path/to/Coati-7B \ --save_path /path/to/Coati-7B \
--dataset /path/to/data.json \ --dataset /path/to/data.json \
--batch_size 4 \ --batch_size 4 \
--accumulation_steps 8 \ --accumulation_steps 8 \
--lr 2e-5 \ --lr 2e-5 \
--max_datasets_size 512 \ --max_datasets_size 512 \
--max_epochs 1 \ --max_epochs 1
...@@ -4,8 +4,8 @@ import argparse ...@@ -4,8 +4,8 @@ import argparse
from time import time from time import time
import torch import torch
from llama_gptq import load_quant from coati.quant import llama_load_quant, low_resource_init
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
def generate_prompt(instruction, input=None): def generate_prompt(instruction, input=None):
...@@ -106,7 +106,10 @@ if __name__ == "__main__": ...@@ -106,7 +106,10 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(args.pretrained) tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
if args.quant == '4bit': if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda() model.cuda()
else: else:
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
......
# Package init: re-export the public GPTQ loading helper.
from .loader import load_quant

__all__ = [
    'load_quant',
]
import torch
import torch.nn as nn
import transformers
from transformers import LlamaConfig, LlamaForCausalLM
from .model_utils import find_layers
from .quant import make_quant
def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
    """Build a LlamaForCausalLM skeleton and load GPTQ-quantized weights.

    Args:
        pretrained: model name or path used to fetch the ``LlamaConfig``.
        checkpoint: path to the quantized state dict (``.pt`` or ``.safetensors``).
        wbits: quantization bit width (e.g. 4).
        groupsize: GPTQ quantization group size.

    Returns:
        The quantized model in eval mode with ``model.seqlen`` set to 2048.
    """
    config = LlamaConfig.from_pretrained(pretrained)

    def noop(*args, **kwargs):
        pass

    # Temporarily disable weight initialization while the skeleton is built:
    # the real weights come from the checkpoint below, so random init would
    # only waste time. Save and restore the patched globals so this leak does
    # not affect models constructed later in the process (the original code
    # never restored them, and also set the half default dtype twice).
    saved_inits = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_)
    saved_init_weights = transformers.modeling_utils._init_weights
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    try:
        model = LlamaForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float)
        (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_) = saved_inits
        transformers.modeling_utils._init_weights = saved_init_weights
    model = model.eval()

    # Swap every linear-ish layer except the LM head for its quantized version.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize)

    print(f'Loading model with {wbits} bits...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')
    return model
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
import torch
import torch.nn as nn
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    """Recursively collect submodules whose exact type is listed in ``layers``.

    Returns a dict mapping dotted attribute paths (relative to ``module``)
    to the matching submodules. A direct match returns ``{name: module}``.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        prefix = f'{name}.{child_name}' if name else child_name
        found.update(find_layers(child, layers=layers, name=prefix))
    return found
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
import math
import numpy as np
import torch
import torch.nn as nn
def quantize(x, scale, zero, maxq):
    """Fake-quantize ``x``: snap to the integer grid, clamp to [0, maxq], dequantize."""
    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)
class Quantizer(nn.Module):
    """Computes and stores quantization parameters for a tensor.

    Buffers:
        maxq:  number of quantization levels minus one (2**bits - 1).
        scale: step size between adjacent levels (per channel or per tensor).
        zero:  zero-point offset, expressed in level units.
    """

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
        # mse/norm/grid/maxshrink drive the optional range-shrinking search in
        # find_params that minimizes the Lp reconstruction error.
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink

    def find_params(self, x, weight=False):
        """Derive scale/zero from the value range of ``x``.

        ``weight=True`` treats dim 0 as the channel dim; otherwise the layout
        is inferred from the rank (4D conv activations, 3D sequences, 2D).
        Results are reshaped so they broadcast against the original ``x``.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                # Activations: bring the channel dim first, flatten the rest.
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            # Per-tensor: one row covering everything.
            x = x.flatten().unsqueeze(0)

        # Range always includes 0 so that 0 is exactly representable.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            # Symmetric: mirror the range around 0.
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # Degenerate all-zero rows: force a [-1, 1] range to avoid scale == 0.
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
        else:
            self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Grid search over progressively shrunken ranges, keeping the
            # (scale, zero) pair with the smallest Lp error per row.
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            # Broadcast the single (scale, zero) pair across the channel dim.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            # Shape as (channels, 1, 1, ...) to broadcast over weight tensors.
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        # Reshape for broadcasting against activations of the original rank.
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        # Pass-through until find_params has produced a non-zero scale.
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        # True once find_params has been called (scale is never 0 afterwards).
        return torch.all(self.scale != 0)
# The fused low-bit matmul kernels are optional: without them the CPU-side
# packing utilities in this file remain importable; only QuantLinear.forward
# needs the extension.
try:
    import quant_cuda
except ImportError:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
    # not silently swallowed during import.
    print('CUDA extension not installed.')
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
    """Linear layer whose weights are stored bit-packed at 2/3/4/8 bits.

    Integer weight levels are packed into int32 buffers (``qweight``) together
    with per-group quantization parameters (``scales``/packed ``qzeros``);
    the matmul itself is delegated to the ``quant_cuda`` kernels in forward().
    """

    def __init__(self, bits, groupsize, infeatures, outfeatures):
        super().__init__()
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        # groupsize must be -1 (one group spanning all inputs) or a power of two >= 32.
        if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
            raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
        groupsize = groupsize if groupsize != -1 else infeatures
        self.groupsize = groupsize
        # Packed zero-points: one row per group, `bits` bits per output packed into int32 words.
        self.register_buffer(
            'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
                                  dtype=torch.int))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
        self.register_buffer('bias', torch.zeros(outfeatures))
        # Packed weights: `bits` bits per input value, packed along dim 0 into int32 words.
        self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
        # Lazily-checked flag used by forward(); see NOTE there.
        self._initialized_quant_state = False

    def pack(self, linear, scales, zeros):
        """Quantize ``linear``'s weights with the given per-group ``scales``/``zeros``
        and store them bit-packed into this module's buffers (offline, CPU step).
        """
        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone()
        if linear.bias is not None:
            self.bias = linear.bias.clone()

        # Integer level per weight: round((w + zero*scale) / scale), column by column.
        intweight = []
        for idx in range(self.infeatures):
            g_idx = idx // self.groupsize
            intweight.append(
                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
                                                                                                                 None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)

        # Pack the integer levels into 32-bit words along the input dimension.
        qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                # Bit-widths that divide 32: exactly 32/bits values per word.
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            elif self.bits == 3:
                # 3 bits does not divide 32: 32 values span 3 words, with the
                # straddling values split across adjacent word boundaries.
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i))
                i += 10
                qweight[row] |= intweight[i] << 30
                row += 1
                qweight[row] |= (intweight[i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
                i += 10
                qweight[row] |= intweight[i] << 31
                row += 1
                qweight[row] |= (intweight[i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
                i += 10
                row += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        # Zero-points are stored offset by -1 (the kernels account for this),
        # then packed with the same scheme as the weights, along columns.
        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
                i += 10
                qzeros[:, col] |= zeros[:, i] << 30
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
                i += 10
                qzeros[:, col] |= zeros[:, i] << 31
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
                i += 10
                col += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        """Compute the quantized linear transform via the quant_cuda kernels.

        An all-zero bias buffer is treated as "no bias" and dropped; a real
        bias is cast to the intermediate dtype.
        NOTE(review): _initialized_quant_state is checked but never set to
        True anywhere in this class, so this normalization re-runs on every
        call. It is idempotent, but looks unintended — confirm upstream.
        """
        intermediate_dtype = torch.float32

        if not self._initialized_quant_state:
            # Do we even have a bias? Check for at least one non-zero element.
            if self.bias is not None and bool(torch.any(self.bias != 0)):
                # Then make sure it's the right type.
                self.bias.data = self.bias.data.to(intermediate_dtype)
            else:
                self.bias = None

        # Flatten all leading dims into a batch; restore them on the way out.
        outshape = list(x.shape)
        outshape[-1] = self.outfeatures
        x = x.reshape(-1, x.shape[-1])
        if self.bias is None:
            y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        else:
            # The kernels accumulate into y, so seed it with the bias.
            y = self.bias.clone().repeat(x.shape[0], 1)

        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        if self.bits == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)
def make_quant(module, names, bits, groupsize, name=''):
    """Recursively replace every layer whose dotted path is in ``names`` with
    an uninitialized QuantLinear of matching dimensions (weights are loaded
    later from a checkpoint). Already-converted subtrees are left untouched.
    """
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        candidate = getattr(module, attr)
        full_name = f'{name}.{attr}' if name else attr
        if full_name in names:
            setattr(module, attr, QuantLinear(bits, groupsize, candidate.in_features, candidate.out_features))
    for child_name, child in module.named_children():
        child_path = f'{name}.{child_name}' if name else child_name
        make_quant(child, names, bits, groupsize, child_path)
...@@ -5,8 +5,7 @@ from locust import HttpUser, task ...@@ -5,8 +5,7 @@ from locust import HttpUser, task
samples = [[ samples = [[
dict( dict(
instruction='Who is the best player in the history of NBA?', instruction='Who is the best player in the history of NBA?',
response= response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
), ),
dict(instruction='continue this talk', response=''), dict(instruction='continue this talk', response=''),
], [ ], [
......
import argparse import argparse
import os import os
from threading import Lock from threading import Lock
from typing import Dict, Generator, List, Optional from typing import Generator, List, Optional
import torch import torch
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException, Request from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from llama_gptq import load_quant
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
...@@ -56,7 +56,7 @@ app.add_middleware( ...@@ -56,7 +56,7 @@ app.add_middleware(
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
#TODO(ver217): streaming generation does not support repetition_penalty now # TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = { model_kwargs = {
'max_generate_tokens': max_new_tokens, 'max_generate_tokens': max_new_tokens,
'early_stopping': True, 'early_stopping': True,
...@@ -162,7 +162,10 @@ if __name__ == '__main__': ...@@ -162,7 +162,10 @@ if __name__ == '__main__':
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
if args.quant == '4bit': if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size) with low_resource_init():
config = LlamaConfig.from_pretrained(args.pretrained)
model = LlamaForCausalLM(config)
model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda() model.cuda()
else: else:
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
......
...@@ -10,37 +10,34 @@ samples = [ ...@@ -10,37 +10,34 @@ samples = [
([ ([
Dialogue( Dialogue(
instruction='Who is the best player in the history of NBA?', instruction='Who is the best player in the history of NBA?',
response= response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
), ),
Dialogue(instruction='continue this talk', response=''), Dialogue(instruction='continue this talk', response=''),
], 128, ], 128,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
), ),
([ ([
Dialogue( Dialogue(
instruction='Who is the best player in the history of NBA?', instruction='Who is the best player in the history of NBA?',
response= response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
), ),
Dialogue(instruction='continue this talk', response=''), Dialogue(instruction='continue this talk', response=''),
], 200, ], 200,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
), ),
([ ([
Dialogue( Dialogue(
instruction='Who is the best player in the history of NBA?', instruction='Who is the best player in the history of NBA?',
response= response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
), ),
Dialogue(instruction='continue this talk', response=''), Dialogue(instruction='continue this talk', response=''),
], 211, ], 211,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
), ),
([ ([
Dialogue(instruction='Who is the best player in the history of NBA?', response=''), Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
], 128, ], 128,
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
), ),
] ]
......
import json
import re import re
from threading import Lock from threading import Lock
from typing import Any, Callable, Generator, List, Optional from typing import Any, Callable, Generator, List, Optional
import json
import jieba
import jieba
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
...@@ -127,7 +127,7 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) ...@@ -127,7 +127,7 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
class ChatPromptProcessor: class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]): def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.context = context self.context = context
self.max_len = max_len self.max_len = max_len
...@@ -182,6 +182,7 @@ class ChatPromptProcessor: ...@@ -182,6 +182,7 @@ class ChatPromptProcessor:
intersection = set(jieba.cut(text.lower())) & self.censored_words intersection = set(jieba.cut(text.lower())) & self.censored_words
return len(intersection) > 0 return len(intersection) > 0
class LockedIterator: class LockedIterator:
def __init__(self, it, lock: Lock) -> None: def __init__(self, it, lock: Lock) -> None:
...@@ -195,6 +196,7 @@ class LockedIterator: ...@@ -195,6 +196,7 @@ class LockedIterator:
with self.lock: with self.lock:
return next(self.it) return next(self.it)
def load_json(path: str): def load_json(path: str):
with open(path) as f: with open(path) as f:
return json.load(f) return json.load(f)
\ No newline at end of file
#!/bin/bash
set -xue

echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies."

# Run every strategy only when explicitly requested; default to the single
# most common one to keep CI time bounded.
if [[ $# -ne 0 && "$1" == "verbose" ]]; then
    STRATEGIES=(
        'ddp'
        'colossalai_gemini'
        'colossalai_gemini_cpu'
        'colossalai_zero2'
        'colossalai_zero2_cpu'
        'colossalai_zero1'
        'colossalai_zero1_cpu'
    )
else
    STRATEGIES=(
        'colossalai_zero2'
    )
fi

# Resolve the repo layout relative to this script; quoted so paths with
# spaces survive (fixes SC2086 on the original unquoted expansions).
BASE_DIR=$(dirname "$(dirname "$(realpath "${BASH_SOURCE[0]}")")")
BENCHMARKS_DIR="$BASE_DIR/benchmarks"

echo "[Test]: testing benchmarks ..."

for strategy in "${STRATEGIES[@]}"; do
    torchrun --standalone --nproc_per_node 1 "$BENCHMARKS_DIR/benchmark_opt_lora_dummy.py" \
        --model 125m --critic_model 125m --strategy "${strategy}" --lora_rank 4 \
        --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \
        --train_batch_size 2 --experience_batch_size 4
done
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from coati.models.gpt import GPTActor from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs from coati.models.utils import calc_action_log_probs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
...@@ -17,40 +17,41 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) ...@@ -17,40 +17,41 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
def get_data(batch_size: int, seq_len: int = 10) -> dict: def get_data(batch_size: int, seq_len: int = 10) -> dict:
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
def run_test_checkpoint(strategy): def train_step(strategy: Strategy,
BATCH_SIZE = 2 actor: GPTActor,
actor_optim: HybridAdam,
batch_size: int = 8):
data = get_data(batch_size)
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
actor_output = actor(data["input_ids"], data["attention_mask"])
action_log_probs = calc_action_log_probs(actor_output, data["input_ids"], action_mask.size(1))
loss = action_log_probs.sum()
strategy.backward(loss, actor, actor_optim)
strategy.optimizer_step(actor_optim)
if strategy == 'ddp':
def run_test_checkpoint(strategy_name: str,
shard: bool):
if strategy_name == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif strategy == 'colossalai_gemini': elif strategy_name == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == 'colossalai_zero2': elif strategy_name == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: else:
raise ValueError(f'Unsupported strategy "{strategy}"') raise ValueError(f"Unsupported strategy '{strategy_name}'")
with strategy.model_init_context(): with strategy.model_init_context():
actor = GPTActor(config=GPT_CONFIG).cuda() actor = GPTActor(config=GPT_CONFIG).cuda()
actor_optim = HybridAdam(actor.parameters()) actor_optim = HybridAdam(actor.parameters())
actor, actor_optim = strategy.prepare((actor, actor_optim)) actor, actor_optim = strategy.prepare((actor, actor_optim))
def run_step(): train_step(strategy, actor, actor_optim)
data = get_data(BATCH_SIZE)
action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
actor_output = actor(data['input_ids'], data['attention_mask'])
action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
loss = action_log_probs.sum()
strategy.backward(loss, actor, actor_optim)
strategy.optimizer_step(actor_optim)
run_step()
ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
...@@ -59,43 +60,47 @@ def run_test_checkpoint(strategy): ...@@ -59,43 +60,47 @@ def run_test_checkpoint(strategy):
dist.broadcast_object_list(rank0_dirname) dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0] rank0_dirname = rank0_dirname[0]
model_path = os.path.join(rank0_dirname, 'model.pt') model_path = os.path.join(
strategy.save_model(actor, model_path, only_rank0=True) rank0_dirname, "model" if shard else f"model.pt")
strategy.save_model(actor, model_path, only_rank0=not shard)
optim_path = os.path.join(rank0_dirname, f'optim.pt') optim_path = os.path.join(
strategy.save_optimizer(actor_optim, optim_path, only_rank0=True) rank0_dirname, "optim" if shard else "optim.pt")
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
# FIXME(cwher): Sharded optimizer checkpoint is not supported yet.
# at "ColossalAI/colossalai/checkpoint_io/general_checkpoint_io.py", line 62
# optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
# strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
dist.barrier() dist.barrier()
strategy.load_model(actor, model_path, strict=False) strategy.load_model(actor, model_path, strict=False)
strategy.load_optimizer(actor_optim, optim_path) strategy.load_optimizer(actor_optim, optim_path)
dist.barrier() dist.barrier()
run_step() train_step(strategy, actor, actor_optim)
def run_dist(rank, world_size, port, strategy): def run_dist(rank: int,
os.environ['RANK'] = str(rank) world_size: int,
os.environ['LOCAL_RANK'] = str(rank) port: int,
os.environ['WORLD_SIZE'] = str(world_size) strategy_name: str,
os.environ['MASTER_ADDR'] = 'localhost' shard: bool):
os.environ['MASTER_PORT'] = str(port) os.environ["RANK"] = str(rank)
run_test_checkpoint(strategy) os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
run_test_checkpoint(strategy_name, shard)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) @pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy): def test_checkpoint(world_size: int,
spawn(run_dist, world_size, strategy=strategy) strategy_name: str,
shard: bool):
spawn(run_dist,
world_size,
strategy_name=strategy_name,
shard=shard)
if __name__ == '__main__': if __name__ == "__main__":
test_checkpoint(2, 'colossalai_zero2') test_checkpoint(2, "colossalai_gemini", shard=False)
import json
import os
import tempfile
from typing import Optional
import pytest
import torch
from coati.dataset.prompt_dataset import PromptDataset
from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
# Toy instruction-tuning samples (Alpaca-style schema: instruction/input/output/id),
# presumably consumed by the SFT dataset tests below — written to a temp JSON file.
SFT_DATASET = [
    {
        "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
        "input": "",
        "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
        "id": 0
    },
    {
        "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
        "input": "",
        "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
        "id": 1
    },
    {
        "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
        "input": "",
        "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
        "id": 2
    },
]
# Toy prompt-only samples (instruction/id) used by test_prompt_dataset below;
# serialized to a temp JSON file and loaded through PromptDataset.
PROMPT_DATASET = [
    {
        "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
        "id": 0
    },
    {
        "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
        "id": 1
    },
    {
        "instruction": "Write a persuasive essay arguing why homework should be banned in schools",
        "id": 2
    },
    {
        "instruction": "Create a chart comparing the statistics on student debt in the United States.",
        "id": 3
    },
]
def make_tokenizer(model: str):
    """Return a pretrained tokenizer for the given toy model family.

    Llama's tokenizer ships without a pad token, so its unk token is reused
    for padding; the other families reuse their eos token.

    Raises:
        ValueError: if ``model`` is not one of gpt2/bloom/opt/llama.
    """
    loaders = {
        "gpt2": lambda: GPT2Tokenizer.from_pretrained("gpt2"),
        "bloom": lambda: BloomTokenizerFast.from_pretrained("bigscience/bloom-560m"),
        "opt": lambda: AutoTokenizer.from_pretrained("facebook/opt-350m"),
        "llama": lambda: LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer"),
    }
    if model not in loaders:
        raise ValueError(f"Unsupported model '{model}'")
    tokenizer = loaders[model]()
    tokenizer.pad_token = tokenizer.unk_token if model == "llama" else tokenizer.eos_token
    return tokenizer
def check_content(input_ids_stripped: torch.Tensor,
                  tokenizer: PreTrainedTokenizer,
                  model: str):
    """Sanity-check a de-padded token sequence: strip the model-specific
    leading special token, then assert no special tokens remain inside.

    Args:
        input_ids_stripped: 1-D tensor of token ids with padding removed.
        tokenizer: tokenizer providing the special-token ids to check against.
        model: model family name ("gpt2"/"bloom"/"opt"/"llama").
    """
    if model == "opt":
        # NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt.
        assert input_ids_stripped[0] == tokenizer.eos_token_id
        input_ids_stripped = input_ids_stripped[1:]
    elif model == "llama":
        assert input_ids_stripped[0] == tokenizer.bos_token_id
        input_ids_stripped = input_ids_stripped[1:]

    assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
    assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
    assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
    # FIX: the original asserted directly on the element-wise comparison
    # (`assert tensor != id`); with a real (non-None) special id and more than
    # one token that raises "Boolean value of Tensor with more than one
    # element is ambiguous" instead of checking anything. Guard None ids
    # (many tokenizers define no sep/cls/mask) and reduce with torch.all.
    for special_id in (tokenizer.sep_token_id, tokenizer.cls_token_id, tokenizer.mask_token_id):
        if special_id is not None:
            assert torch.all(input_ids_stripped != special_id)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2])
def test_prompt_dataset(model: str,
                        max_datasets_size: int,
                        max_length: int):
    """PromptDataset truncates to max_datasets_size and yields fixed-length,
    correctly padded (input_ids, attention_mask) samples."""
    with tempfile.TemporaryDirectory() as workdir:
        data_file = os.path.join(workdir, "prompt_dataset.json")
        with open(data_file, "w") as f:
            json.dump(PROMPT_DATASET, f)
        tokenizer = make_tokenizer(model)
        assert tokenizer.padding_side in ("left", "right")
        prompt_dataset = PromptDataset(data_path=data_file,
                                       tokenizer=tokenizer,
                                       max_datasets_size=max_datasets_size,
                                       max_length=max_length)
        assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
        for idx in range(len(prompt_dataset)):
            sample = prompt_dataset[idx]
            assert isinstance(sample, dict)
            assert list(sample.keys()) == ["input_ids", "attention_mask"]
            input_ids = sample["input_ids"]
            mask = sample["attention_mask"].bool()
            assert input_ids.shape == mask.shape == torch.Size([max_length])
            # Positions outside the mask must be padding.
            assert torch.all(input_ids[torch.logical_not(mask)] == tokenizer.pad_token_id)
            check_content(input_ids.masked_select(mask), tokenizer, model)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize(["dataset_path", "subset"], [
    ("Anthropic/hh-rlhf", "harmless-base"),
    ("Dahoas/rm-static", None)
])
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str,
                        dataset_path: str,
                        subset: Optional[str],
                        max_datasets_size: int,
                        max_length: int):
    """Check padding, truncation and content of the chosen/rejected pairs
    produced by the reward datasets (HhRlhfDataset / RmStaticDataset) on
    both the train and test splits."""
    data = load_dataset(dataset_path, data_dir=subset)
    assert max_datasets_size <= len(data["train"]) \
        and max_datasets_size <= len(data["test"])
    train_data = data["train"].select(range(max_datasets_size))
    test_data = data["test"].select(range(max_datasets_size))

    tokenizer = make_tokenizer(model)
    assert tokenizer.padding_side in ("left", "right")

    if dataset_path == "Anthropic/hh-rlhf":
        train_dataset = HhRlhfDataset(train_data, tokenizer, max_length)
        test_dataset = HhRlhfDataset(test_data, tokenizer, max_length)
    elif dataset_path == "Dahoas/rm-static":
        train_dataset = RmStaticDataset(train_data, tokenizer, max_length)
        test_dataset = RmStaticDataset(test_data, tokenizer, max_length)
    else:
        raise ValueError(f'Unsupported dataset "{dataset_path}"')

    assert len(train_dataset) == len(test_dataset) == max_datasets_size

    def check_sequence(ids: torch.Tensor, mask: torch.Tensor) -> None:
        # Shared per-sequence check for one side (chosen or rejected) of a pair.
        mask = mask.to(torch.bool)
        if ids.masked_select(mask)[-1] == tokenizer.eos_token_id:
            # Sequence fit within max_length: content ends with EOS and any
            # masked-out positions must be padding.
            check_content(ids.masked_select(mask)[:-1], tokenizer, model)
            assert torch.all(ids.masked_select(torch.logical_not(mask)) == tokenizer.pad_token_id)
        else:
            # Sequence was truncated at max_length, so no padding may exist.
            check_content(ids.masked_select(mask), tokenizer, model)
            assert torch.all(mask)

    # FIX: the original repeated the identical check block verbatim for the
    # train and test splits (and for chosen vs. rejected inside each); run
    # the shared helper over both splits instead.
    for dataset in (train_dataset, test_dataset):
        for i in range(max_datasets_size):
            chosen_ids, c_mask, reject_ids, r_mask = dataset[i]
            assert chosen_ids.shape == c_mask.shape == \
                reject_ids.shape == r_mask.shape == torch.Size([max_length])
            check_sequence(chosen_ids, c_mask)
            check_sequence(reject_ids, r_mask)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_sft_dataset(model: str,
                     dataset_path: Optional[str],
                     max_dataset_size: int,
                     max_length: int):
    """Check that the SFT datasets (SFTDataset / SupervisedDataset) pad or
    truncate to max_length, mask padding, and mark prompt tokens with
    IGNORE_INDEX in the labels."""
    tokenizer = make_tokenizer(model)
    if dataset_path == "yizhongw/self_instruct":
        data = load_dataset(dataset_path, "super_natural_instructions")
        train_data = data["train"].select(range(max_dataset_size))
        sft_dataset = SFTDataset(train_data, tokenizer, max_length)
    else:
        # Materialize the local fixture as a JSON file for SupervisedDataset.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dataset_name = "sft_dataset.json"
            with open(os.path.join(tmp_dir, dataset_name), "w") as f:
                json.dump(SFT_DATASET, f)
            sft_dataset = SupervisedDataset(tokenizer=tokenizer,
                                            data_path=os.path.join(tmp_dir, dataset_name),
                                            max_datasets_size=max_dataset_size,
                                            max_length=max_length)
        assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))

    # FIX: iterate over the actual dataset length instead of
    # range(max_dataset_size); the dataset may hold fewer samples (the
    # assert above allows len == min(max_dataset_size, len(SFT_DATASET))),
    # in which case the original loop raised IndexError. This also matches
    # test_prompt_dataset's iteration style.
    for i in range(len(sft_dataset)):
        sample = sft_dataset[i]
        assert isinstance(sample, dict)
        assert list(sample.keys()) == ["input_ids", "labels", "attention_mask"]
        input_ids = sample["input_ids"]
        labels = sample["labels"]
        attention_mask = sample["attention_mask"].to(torch.bool)
        assert input_ids.shape == labels.shape == \
            attention_mask.shape == torch.Size([max_length])
        if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
            # Sequence fit within max_length: masked-out tail must be padding.
            check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
            assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
        else:
            # Sequence was truncated at max_length, so no padding may exist.
            check_content(input_ids.masked_select(attention_mask), tokenizer, model)
            assert torch.all(attention_mask)
        # Prompt positions are excluded from the loss via IGNORE_INDEX; their
        # input ids must still be valid (non-special) content tokens.
        ignore_mask = labels == IGNORE_INDEX
        check_content(input_ids.masked_select(ignore_mask), tokenizer, model)
if __name__ == "__main__":
    # Manual smoke run: exercise one configuration of each dataset test
    # directly, without pytest parametrization (requires network access for
    # the HuggingFace datasets used by the first two calls).
    test_sft_dataset(model="bloom",
                     dataset_path="yizhongw/self_instruct",
                     max_dataset_size=2,
                     max_length=256)
    test_reward_dataset(model="gpt2",
                        dataset_path="Anthropic/hh-rlhf",
                        subset="harmless-base",
                        max_datasets_size=8,
                        max_length=256)
    test_prompt_dataset(model="opt",
                        max_datasets_size=2,
                        max_length=128)
...@@ -4,11 +4,12 @@ from copy import deepcopy ...@@ -4,11 +4,12 @@ from copy import deepcopy
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import NaiveExperienceMaker from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic from coati.models.gpt import GPTActor, GPTCritic
from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
...@@ -32,13 +33,15 @@ def gather_and_equal(tensor: torch.Tensor) -> bool: ...@@ -32,13 +33,15 @@ def gather_and_equal(tensor: torch.Tensor) -> bool:
return True return True
def run_test_data(strategy): def make_and_consume_experience(strategy):
EXPERIENCE_BATCH_SIZE = 4 EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2 SAMPLE_BATCH_SIZE = 2
if strategy == 'ddp': if strategy == 'ddp':
strategy = DDPStrategy() strategy = DDPStrategy()
elif strategy == 'colossalai': elif strategy == 'colossalai-zero2':
strategy = LowLevelZeroStrategy()
elif strategy == 'colossalai-gemini':
strategy = GeminiStrategy(placement_policy='cuda') strategy = GeminiStrategy(placement_policy='cuda')
else: else:
raise ValueError(f'Unsupported strategy "{strategy}"') raise ValueError(f'Unsupported strategy "{strategy}"')
...@@ -50,7 +53,7 @@ def run_test_data(strategy): ...@@ -50,7 +53,7 @@ def run_test_data(strategy):
reward_model = RewardModel(deepcopy(critic.model)).cuda() reward_model = RewardModel(deepcopy(critic.model)).cuda()
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model) experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
# experience of all ranks should be the same # experience of all ranks should be the same
for _ in range(2): for _ in range(2):
...@@ -69,12 +72,12 @@ def run_test_data(strategy): ...@@ -69,12 +72,12 @@ def run_test_data(strategy):
assert gather_and_equal(experience.advantages) assert gather_and_equal(experience.advantages)
assert gather_and_equal(experience.action_mask) assert gather_and_equal(experience.action_mask)
assert gather_and_equal(experience.attention_mask) assert gather_and_equal(experience.attention_mask)
replay_buffer.append(experience) data_buffer.append(experience)
# replay buffer's data should be the same # data buffer's data should be the same
buffer_size = torch.tensor([len(replay_buffer)], device='cuda') buffer_size = torch.tensor([len(data_buffer)], device='cuda')
assert gather_and_equal(buffer_size) assert gather_and_equal(buffer_size)
for item in replay_buffer.items: for item in data_buffer.items:
assert gather_and_equal(item.sequences) assert gather_and_equal(item.sequences)
assert gather_and_equal(item.action_log_probs) assert gather_and_equal(item.action_log_probs)
assert gather_and_equal(item.values) assert gather_and_equal(item.values)
...@@ -84,7 +87,7 @@ def run_test_data(strategy): ...@@ -84,7 +87,7 @@ def run_test_data(strategy):
assert gather_and_equal(item.attention_mask) assert gather_and_equal(item.attention_mask)
# dataloader of each rank should have the same size and different batch # dataloader of each rank should have the same size and different batch
dataloader = strategy.setup_dataloader(replay_buffer) dataloader = strategy.setup_dataloader(data_buffer)
dataloader_size = torch.tensor([len(dataloader)], device='cuda') dataloader_size = torch.tensor([len(dataloader)], device='cuda')
assert gather_and_equal(dataloader_size) assert gather_and_equal(dataloader_size)
for experience in dataloader: for experience in dataloader:
...@@ -102,17 +105,16 @@ def run_dist(rank, world_size, port, strategy): ...@@ -102,17 +105,16 @@ def run_dist(rank, world_size, port, strategy):
os.environ['WORLD_SIZE'] = str(world_size) os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(port) os.environ['MASTER_PORT'] = str(port)
run_test_data(strategy) make_and_consume_experience(strategy)
@pytest.mark.skip
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) @pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_data(world_size, strategy): def test_experience(world_size, strategy):
spawn(run_dist, world_size, strategy=strategy) spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__': if __name__ == '__main__':
test_data(2, 'colossalai') test_experience(2, 'colossalai')
set -xue

# Resolve the repository base dir from this script's own location
# (script lives one level below the repo root, e.g. tests/).
BASE_DIR=$(dirname "$(dirname "$(realpath "${BASH_SOURCE[0]}")")")
EXAMPLES_DIR="$BASE_DIR/examples"

echo "[Test]: testing inference ..."

# HACK: skip llama due to oom
for model in 'gpt2' 'bloom' 'opt'; do
  # Quote expansions so paths with spaces don't word-split (SC2086).
  python "$EXAMPLES_DIR/inference.py" --model "$model"
done
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment