Unverified Commit df5e9c53 authored by YeAnbang's avatar YeAnbang Committed by GitHub

[ColossalChat] Update RLHF V2 (#5286)



* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
parent 36c4bb28
import os
from transformers import AutoTokenizer
from utils import ChatPromptProcessor, Dialogue
CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
samples = [
(
[
Dialogue(
instruction="Who is the best player in the history of NBA?",
response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
),
Dialogue(instruction="continue this talk", response=""),
],
128,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
(
[
Dialogue(
instruction="Who is the best player in the history of NBA?",
response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
),
Dialogue(instruction="continue this talk", response=""),
],
200,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
(
[
Dialogue(
instruction="Who is the best player in the history of NBA?",
response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
),
Dialogue(instruction="continue this talk", response=""),
],
211,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
),
(
[
Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
],
128,
"Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
),
]
def test_chat_prompt_processor():
processor = ChatPromptProcessor(tokenizer, CONTEXT, 256)
for history, max_new_tokens, result in samples:
prompt = processor.preprocess_prompt(history, max_new_tokens)
assert prompt == result
if __name__ == "__main__":
test_chat_prompt_processor()
import json
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
import jieba
import torch
import torch.distributed as dist
import torch.nn as nn
from pydantic import BaseModel, Field
try:
from transformers.generation_logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(
top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
processor_list.append(TopKLogitsWarper(top_k))
if top_p is not None and top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
return processor_list
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
if dist.is_initialized() and dist.get_world_size() > 1:
# consider DP
unfinished_sequences = unfinished_sequences.clone()
dist.all_reduce(unfinished_sequences)
return unfinished_sequences.max() == 0
def sample_streamingly(
model: nn.Module,
input_ids: torch.Tensor,
max_generate_tokens: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> Generator:
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(max_generate_tokens):
model_inputs = (
prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
)
outputs = model(**model_inputs)
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
yield next_tokens
# update generated ids, model inputs for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if update_model_kwargs_fn is not None:
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished if early_stopping=True
if early_stopping and _is_sequence_finished(unfinished_sequences):
break
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "past_key_values" in outputs:
model_kwargs["past"] = outputs["past_key_values"]
else:
model_kwargs["past"] = None
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return model_kwargs
class Dialogue(BaseModel):
instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
response: str = Field(example="")
def _format_dialogue(instruction: str, response: str = ""):
return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
class ChatPromptProcessor:
SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
self.censored_words = set([word.lower() for word in censored_words])
# These will be initialized after the first call of preprocess_prompt()
self.context_len: Optional[int] = None
self.dialogue_placeholder_len: Optional[int] = None
def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
if self.context_len is None:
self.context_len = len(self.tokenizer(self.context)["input_ids"])
if self.dialogue_placeholder_len is None:
self.dialogue_placeholder_len = len(
self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
)
prompt = self.context
# the last dialogue must be in the prompt
last_dialogue = history.pop()
# the response of the last dialogue is empty
assert last_dialogue.response == ""
if (
len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
+ max_new_tokens
+ self.context_len
>= self.max_len
):
# to avoid truncate placeholder, apply truncate to the original instruction
instruction_truncated = self.tokenizer(
last_dialogue.instruction,
add_special_tokens=False,
truncation=True,
max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
)["input_ids"]
instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
prompt += _format_dialogue(instruction_truncated)
return prompt
res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
rows = []
for dialogue in history[::-1]:
text = _format_dialogue(dialogue.instruction, dialogue.response)
cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
if res_len - cur_len < 0:
break
res_len -= cur_len
rows.insert(0, text)
prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt
def postprocess_output(self, output: str) -> str:
output = STOP_PAT.sub("", output)
return output.strip()
def has_censored_words(self, text: str) -> bool:
if len(self.censored_words) == 0:
return False
intersection = set(jieba.cut(text.lower())) & self.censored_words
return len(intersection) > 0
class LockedIterator:
def __init__(self, it, lock: Lock) -> None:
self.lock = lock
self.it = iter(it)
def __iter__(self):
return self
def __next__(self):
with self.lock:
return next(self.it)
def load_json(path: str):
with open(path) as f:
return json.load(f)
#!/bin/bash
set -xue
echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies."
if [[ $# -ne 0 && "$1" == "verbose" ]]; then
STRATEGIES=(
'ddp'
'colossalai_gemini'
'colossalai_gemini_cpu'
'colossalai_zero2'
'colossalai_zero2_cpu'
'colossalai_zero1'
'colossalai_zero1_cpu'
)
else
STRATEGIES=(
'colossalai_zero2'
)
fi
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
BENCHMARKS_DIR=$BASE_DIR/benchmarks
echo "[Test]: testing benchmarks ..."
for strategy in ${STRATEGIES[@]}; do
torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \
--model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \
--num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \
--train_batch_size 2 --experience_batch_size 4
done
import os
import tempfile
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
def get_data(batch_size: int, seq_len: int = 10) -> dict:
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
data = get_data(batch_size)
action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
loss = action_log_probs.sum()
strategy.backward(loss, actor, actor_optim)
strategy.optimizer_step(actor_optim)
def run_test_checkpoint(strategy_name: str, shard: bool):
if strategy_name == "ddp":
strategy = DDPStrategy()
elif strategy_name == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
elif strategy_name == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f"Unsupported strategy '{strategy_name}'")
with strategy.model_init_context():
actor = GPTActor(config=GPT_CONFIG).cuda()
actor_optim = HybridAdam(actor.parameters())
actor, actor_optim = strategy.prepare((actor, actor_optim))
train_step(strategy, actor, actor_optim)
ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
with ctx as dirname:
rank0_dirname = [dirname]
dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0]
model_path = os.path.join(rank0_dirname, "model" if shard else "model.pt")
strategy.save_model(actor, model_path)
optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
strategy.save_optimizer(actor_optim, optim_path)
dist.barrier()
strategy.load_model(actor, model_path, strict=False)
strategy.load_optimizer(actor_optim, optim_path)
dist.barrier()
train_step(strategy, actor, actor_optim)
def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
run_test_checkpoint(strategy_name, shard)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard)
if __name__ == "__main__":
test_checkpoint(2, "colossalai_gemini", shard=False)
import json
import os
import tempfile
from typing import Optional
import pytest
import torch
from coati.dataset.prompt_dataset import PromptDataset
from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
SFT_DATASET = [
{
"instruction": "Provide a list of the top 10 most popular mobile games in Asia",
"input": "",
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id": 0,
},
{
"instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
"input": "",
"output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
"id": 1,
},
{
"instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
"input": "",
"output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
"id": 2,
},
]
PROMPT_DATASET = [
{
"instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."',
"id": 0,
},
{"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1},
{"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2},
{"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3},
]
def make_tokenizer(model: str):
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
elif model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token
elif model == "chatglm":
tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
else:
raise ValueError(f"Unsupported model '{model}'")
return tokenizer
def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
if model == "opt":
# NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt.
assert input_ids_stripped[0] == tokenizer.eos_token_id
input_ids_stripped = input_ids_stripped[1:]
elif model == "llama":
assert input_ids_stripped[0] == tokenizer.bos_token_id
input_ids_stripped = input_ids_stripped[1:]
elif model == "chatglm":
assert input_ids_stripped[0] == tokenizer.bos_token_id
assert input_ids_stripped[-1] == tokenizer.eos_token_id
input_ids_stripped = input_ids_stripped[1:-1]
assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
assert input_ids_stripped != tokenizer.sep_token_id
assert input_ids_stripped != tokenizer.cls_token_id
if model == "chatglm":
assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
else:
assert input_ids_stripped != tokenizer.mask_token_id
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2])
def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
with tempfile.TemporaryDirectory() as tmp_dir:
dataset_name = "prompt_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
json.dump(PROMPT_DATASET, f)
tokenizer = make_tokenizer(model)
assert tokenizer.padding_side in ("left", "right")
prompt_dataset = PromptDataset(
data_path=os.path.join(tmp_dir, dataset_name),
tokenizer=tokenizer,
max_datasets_size=max_datasets_size,
max_length=max_length,
)
assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
for i in range(len(prompt_dataset)):
assert isinstance(prompt_dataset[i], dict)
assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"]
input_ids = prompt_dataset[i]["input_ids"]
attention_mask = prompt_dataset[i]["attention_mask"]
attention_mask = attention_mask.bool()
assert input_ids.shape == attention_mask.shape == torch.Size([max_length])
assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id)
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize(
["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)]
)
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset)
assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"])
train_data = data["train"].select(range(max_datasets_size))
test_data = data["test"].select(range(max_datasets_size))
tokenizer = make_tokenizer(model)
assert tokenizer.padding_side in ("left", "right")
if dataset_path == "Anthropic/hh-rlhf":
train_dataset = HhRlhfDataset(train_data, tokenizer, max_length)
test_dataset = HhRlhfDataset(test_data, tokenizer, max_length)
elif dataset_path == "Dahoas/rm-static":
train_dataset = RmStaticDataset(train_data, tokenizer, max_length)
test_dataset = RmStaticDataset(test_data, tokenizer, max_length)
else:
raise ValueError(f'Unsupported dataset "{dataset_path}"')
assert len(train_dataset) == len(test_dataset) == max_datasets_size
for i in range(max_datasets_size):
chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
else:
check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
assert torch.all(c_mask)
if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
else:
check_content(reject_ids.masked_select(r_mask), tokenizer, model)
assert torch.all(r_mask)
chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
c_mask = c_mask.to(torch.bool)
r_mask = r_mask.to(torch.bool)
if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
else:
check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
assert torch.all(c_mask)
if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
else:
check_content(reject_ids.masked_select(r_mask), tokenizer, model)
assert torch.all(r_mask)
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
tokenizer = make_tokenizer(model)
if dataset_path == "yizhongw/self_instruct":
data = load_dataset(dataset_path, "super_natural_instructions")
train_data = data["train"].select(range(max_dataset_size))
sft_dataset = SFTDataset(train_data, tokenizer, max_length)
else:
with tempfile.TemporaryDirectory() as tmp_dir:
dataset_name = "sft_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
json.dump(SFT_DATASET, f)
sft_dataset = SupervisedDataset(
tokenizer=tokenizer,
data_path=os.path.join(tmp_dir, dataset_name),
max_datasets_size=max_dataset_size,
max_length=max_length,
)
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
if isinstance(tokenizer, ChatGLMTokenizer):
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"]
assert input_ids.shape == labels.shape == torch.Size([max_length])
ignore_mask = labels == IGNORE_INDEX
assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
return
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
input_ids = sft_dataset[i]["input_ids"]
labels = sft_dataset[i]["labels"]
attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length])
if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
else:
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
assert torch.all(attention_mask)
ignore_mask = labels == IGNORE_INDEX
prompt_mask = torch.logical_and(ignore_mask, attention_mask)
check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
if __name__ == "__main__":
test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
test_reward_dataset(
model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256
)
test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
import copy
import os
import pytest
import torch
import torch.distributed as dist
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
from coati.trainer.ppo import _set_default_generate_kwargs
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.testing import rerun_if_address_is_in_use, spawn
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
def get_data(batch_size: int, seq_len: int = 10) -> dict:
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def gather_and_equal(tensor: torch.Tensor) -> bool:
world_size = dist.get_world_size()
outputs = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(outputs, tensor.contiguous())
for t in outputs[1:]:
if not torch.equal(outputs[0], t):
return False
return True
def make_and_consume_experience(strategy):
EXPERIENCE_BATCH_SIZE = 4
SAMPLE_BATCH_SIZE = 2
if strategy == "ddp":
strategy = DDPStrategy()
elif strategy == "colossalai-zero2":
strategy = LowLevelZeroStrategy()
elif strategy == "colossalai-gemini":
strategy = GeminiStrategy(placement_policy="static")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
with strategy.model_init_context():
actor = GPTActor(config=GPT_CONFIG).cuda()
critic = GPTCritic(config=GPT_CONFIG).cuda()
initial_model = GPTActor(config=GPT_CONFIG).cuda()
reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
class MockTokenizer:
def __init__(self):
self.padding_side = "left"
self.eos_token_id = 0
self.pad_token_id = 0
tokenizer = MockTokenizer()
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
generate_kwargs = dict(do_sample=True, max_length=16)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
# experience of all ranks should be the same
for _ in range(2):
data = get_data(EXPERIENCE_BATCH_SIZE)
assert gather_and_equal(data["input_ids"])
assert gather_and_equal(data["attention_mask"])
experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
assert gather_and_equal(experience.sequences)
assert gather_and_equal(experience.action_log_probs)
assert gather_and_equal(experience.values)
assert gather_and_equal(experience.reward)
assert gather_and_equal(experience.advantages)
assert gather_and_equal(experience.action_mask)
assert gather_and_equal(experience.attention_mask)
data_buffer.append(experience)
# data buffer's data should be the same
buffer_size = torch.tensor([len(data_buffer)], device="cuda")
assert gather_and_equal(buffer_size)
for item in data_buffer.items:
assert gather_and_equal(item.sequences)
assert gather_and_equal(item.action_log_probs)
assert gather_and_equal(item.values)
assert gather_and_equal(item.reward)
assert gather_and_equal(item.advantages)
assert gather_and_equal(item.action_mask)
assert gather_and_equal(item.attention_mask)
# dataloader of each rank should have the same size and different batch
dataloader = strategy.setup_dataloader(data_buffer)
dataloader_size = torch.tensor([len(dataloader)], device="cuda")
assert gather_and_equal(dataloader_size)
for experience in dataloader:
assert not gather_and_equal(experience.sequences)
assert not gather_and_equal(experience.action_log_probs)
assert not gather_and_equal(experience.values)
assert not gather_and_equal(experience.reward)
assert not gather_and_equal(experience.advantages)
# action mask and attention mask may be same
def run_dist(rank, world_size, port, strategy):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
make_and_consume_experience(strategy)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
spawn(run_dist, world_size, strategy=strategy)
if __name__ == "__main__":
test_experience(2, "colossalai-zero2")
set -xue
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
EXAMPLES_DIR=$BASE_DIR/examples
echo "[Test]: testing inference ..."
# HACK: skip llama due to oom
for model in 'gpt2' 'bloom' 'opt'; do
python $EXAMPLES_DIR/inference.py --model $model
done
import copy
from typing import Any, Callable, Dict, Tuple
import pytest
import torch
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, masked_mean
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize(
"actor_maker",
[
lambda: BLOOMActor(),
lambda: GPTActor(),
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
lambda: OPTActor(),
],
)
@pytest.mark.parametrize(
"generate_kwargs",
[
{
"max_length": 64,
"use_cache": True,
"do_sample": True,
"temperature": 1.0,
"top_k": 50,
}
],
)
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
class MockTokenizer:
def __init__(self):
self.padding_side = "left"
self.eos_token_id = 0
self.pad_token_id = 0
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
tokenizer = MockTokenizer()
sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
def test_utils():
fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
fn_output = masked_mean(dim=0, **fn_input)
assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0))
batch_size = 4
seq_len = 32
num_labels = 10
num_actions = 2
fn_input = {
"logits": torch.randn((batch_size, seq_len, num_labels)),
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
"num_actions": num_actions,
}
fn_output = calc_action_log_probs(**fn_input)
assert fn_output.shape == (batch_size, num_actions)
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int, num_dim: int, num_layers: int):
model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
lora_model = convert_to_lora_module(model, lora_rank)
assert isinstance(lora_model, nn.ModuleList)
for i in range(num_layers):
assert isinstance(lora_model[i], LoraLinear)
assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
assert lora_model[i].lora_B.shape == (num_dim, lora_rank)
old_model = copy.deepcopy(lora_model)
for i in range(num_layers):
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
optimizer = torch.optim.Adam(lora_model.parameters())
x = torch.randn(8, num_dim)
for i in range(num_layers):
x = lora_model[i](x)
loss = x.sum()
loss.backward()
optimizer.step()
for i in range(num_layers):
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert not torch.allclose(
old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
"models_maker",
[
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()),
lambda: (ChatGLMActor(), None, None),
],
)
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
critic_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
rm_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
}
actor, critic, rm = models_maker()
if isinstance(actor, ChatGLMActor):
actor = actor.float()
tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
actor_input = {
"input_ids": torch.cat(
(
torch.randint(0, 100, (batch_size, seq_len // 2)),
chatglm_special_token,
torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
),
dim=1,
),
"attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
}
assert isinstance(actor, Actor)
get_base_model(actor)
actor_output = actor(**actor_input)
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
if critic:
assert isinstance(critic, Critic)
get_base_model(critic)
critic_output = critic(**critic_input)
assert critic_output.shape == (batch_size,)
if rm:
assert isinstance(rm, RewardModel)
get_base_model(rm)
rm_output = rm(**rm_input)
assert rm_output.shape == (batch_size,)
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
"labels": torch.randint(0, num_labels, (batch_size, seq_len)),
}
loss(**loss_input)
loss = PolicyLoss()
loss_input = {
"log_probs": torch.randn(
batch_size,
),
"old_log_probs": torch.randn(
batch_size,
),
"advantages": torch.randn(
batch_size,
),
}
loss(**loss_input)
loss = ValueLoss()
loss_input = {
"values": torch.randn(
batch_size,
),
"old_values": torch.randn(
batch_size,
),
"reward": torch.randn(
batch_size,
),
}
loss(**loss_input)
loss = LogSigLoss()
loss_input = {
"chosen_reward": torch.randn(
batch_size,
),
"reject_reward": torch.randn(
batch_size,
),
}
loss(**loss_input)
loss = LogExpLoss()
loss_input = {
"chosen_reward": torch.randn(
batch_size,
),
"reject_reward": torch.randn(
batch_size,
),
}
loss(**loss_input)
if __name__ == "__main__":
generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
test_utils()
test_lora(lora_rank=2, num_dim=8, num_layers=2)
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
test_loss(batch_size=8, seq_len=128, num_labels=100)
#!/usr/bin/env bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4
set -xu
if [ -z "$SFT_DATASET" ]; then
echo "Please set \$SFT_DATASET to the path to sft dataset."
exit 1
fi
if [ -z "$PROMPT_DATASET" ]; then
echo "Please set \$PROMPT_DATASET to the path to prompts csv."
exit 1
fi
if [ -z "$PRETRAIN_DATASET" ]; then
echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
exit 1
fi
NUM_RETRY=3
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
EXAMPLES_DIR=$BASE_DIR/examples
MODELS_DIR=$BASE_DIR/examples/models_config
MODELS=('gpt2' 'bloom' 'opt' 'llama')
STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
export OMP_NUM_THREADS=8
# install requirements
pip install -r $EXAMPLES_DIR/requirements.txt
python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only
get_pretrain() {
local model=$1
if [[ $model == "gpt2" ]]; then
echo "gpt2"
elif [[ $model == "bloom" ]]; then
echo "bigscience/bloom-560m"
elif [[ $model == "opt" ]]; then
echo "facebook/opt-350m"
else
echo "Unknown model $model"
exit 1
fi
}
random_choice() {
local arr=("$@")
local len=${#arr[@]}
local idx=$((RANDOM % len))
echo ${arr[$idx]}
}
echo "[Test]: testing sft ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
)
GRAD_CKPTS=('' '--grad_checkpoint')
for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$strategy-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
echo "[Test]: Skipped $model-$strategy"
continue
fi
pretrain=$(get_pretrain $model)
pretrain_model=""
if [[ $lora_rank -gt 0 ]]; then
pretrain_model="--pretrain $pretrain"
fi
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \
$pretrain_model --tokenizer $MODELS_DIR/$model \
--model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
--dataset $SFT_DATASET --max_datasets_size 8 \
--max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
--save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$strategy-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing reward model ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
)
LOSS_FNS=('log_sig' 'log_exp')
DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$strategy-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
echo "[Test]: Skipped $model-$strategy"
continue
fi
pretrain=$(get_pretrain $model)
pretrain_model=""
if [[ $lora_rank -gt 0 ]]; then
pretrain_model="--pretrain $pretrain"
fi
loss_fn=$(random_choice "${LOSS_FNS[@]}")
dataset=$(random_choice "${DATASETS[@]}")
subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi)
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
$pretrain_model --tokenizer $MODELS_DIR/$model \
--dataset $dataset --subset $subset --max_datasets_size 8 \
--model $model --strategy $strategy --lora_rank $lora_rank \
--loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
--save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing RLHF ..."
# FIXME: This is a hack to skip tests that are not working
# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# - llama-*: These tests can be passed locally, skipped for long execution time
# - *-gemini: Gemini plugin does not support `from_pretrained` yet
SKIPPED_TESTS=(
"gpt2-ddp"
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
)
for model in ${MODELS[@]}; do
for lora_rank in '0'; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$strategy-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
echo "[Test]: Skipped $model-$strategy"
continue
fi
rm_pretrain=$(get_pretrain $model)
rm_pretrain_model=""
if [[ $lora_rank -gt 0 ]]; then
rm_pretrain_model="--rm_pretrain $rm_pretrain"
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
--prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
--strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
--num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank"
exit 1
fi
done
rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
done
done
rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
@@ -143,6 +143,17 @@ docs/.build
*.pt

# wandb log
examples/wandb/
examples/logs/
examples/output/
examples/awesome-chatgpt-prompts/
temp/

# ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/models
applications/ColossalChat/sft_data
applications/ColossalChat/prompt_data
applications/ColossalChat/preference_data
applications/ColossalChat/temp
@@ -13,10 +13,10 @@

- [Install the environment](#install-the-environment)
- [Install the Transformers](#install-the-transformers)
- [How to use?](#how-to-use)
- [Supervised datasets collection](#step-1-data-collection)
- [RLHF Training Stage1 - Supervised instructs tuning](#rlhf-training-stage1---supervised-instructs-tuning)
- [RLHF Training Stage2 - Training reward model](#rlhf-training-stage2---training-reward-model)
- [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#rlhf-training-stage3---proximal-policy-optimization)
- [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training)
- [Coati7B examples](#coati7b-examples)
- [Generation](#generation)
@@ -36,7 +36,7 @@

---

## What Is ColossalChat And Coati ?

[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project.
@@ -91,107 +91,191 @@ More details can be found in the latest news.

## Install

### Install the Environment

```bash
# Create new environment
conda create -n colossal-chat python=3.10.9   # python >= 3.8.7 is required
conda activate colossal-chat

# Install flash-attention
git clone -b v2.0.5 https://github.com/Dao-AILab/flash-attention.git
cd $FLASH_ATTENTION_ROOT/
pip install .
cd $FLASH_ATTENTION_ROOT/csrc/xentropy
pip install .
cd $FLASH_ATTENTION_ROOT/csrc/layer_norm
pip install .
cd $FLASH_ATTENTION_ROOT/csrc/rotary
pip install .

# Clone ColossalAI
git clone https://github.com/hpcaitech/ColossalAI.git

# Install ColossalAI
cd $COLOSSAL_AI_ROOT
BUILD_EXT=1 pip install .

# Install ColossalChat
cd $COLOSSAL_AI_ROOT/applications/Chat
pip install .
```
## How To Use?

### RLHF Training Stage1 - Supervised Instructs Tuning

Stage1 is supervised instruction fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it trains the model on human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat. More details can be found in the [example guideline](./examples/README.md).

#### Step 1: Data Collection

The first step in Stage 1 is to collect a dataset of human demonstrations in the following format.

```json
[
    {"messages":
      [
        {
          "from": "human",
          "content": "what are some pranks with a pen i can do?"
        },
        {
          "from": "assistant",
          "content": "Are you looking for practical joke ideas?"
        },
        ...
      ]
    },
    ...
]
```

#### Step 2: Preprocessing

Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting, and tokenization. In this section, we focus on formatting and tokenization.

We provide a flexible way to set the conversation template used to format chat data, built on Hugging Face's chat template feature. Please follow the [example guideline](./examples/README.md) on how to format and tokenize data.
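For illustration only, the sketch below shows what formatting a single conversation with a Hugging Face chat template might look like; the model name and the `role_map` are assumptions, and the repository's own data preparation scripts (see the example guideline) are the supported path.

```python
# Hypothetical sketch, not the repository's preparation script: format one conversation
# with a Hugging Face chat template. The sample data above uses "from"/"content" keys,
# while `apply_chat_template` expects "role"/"content", so we remap first.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # assumed model

sample = {"messages": [
    {"from": "human", "content": "what are some pranks with a pen i can do?"},
    {"from": "assistant", "content": "Are you looking for practical joke ideas?"},
]}

role_map = {"human": "user", "assistant": "assistant"}  # assumed mapping
chat = [{"role": role_map[m["from"]], "content": m["content"]} for m in sample["messages"]]

# Render the conversation as one training string, then tokenize it.
text = tokenizer.apply_chat_template(chat, tokenize=False)
input_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
```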
#### Step 3: Training

Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start supervised instruction fine-tuning. More details can be found in the [example guideline](./examples/README.md).
### RLHF Training Stage2 - Training Reward Model

Stage2 trains a reward model. Human annotators rank different outputs for the same prompt, and these rankings supervise the reward model to assign corresponding scores.

#### Step 1: Data Collection

Below is the preference dataset format used to train the reward model.

```json
[
    {"context": [
        {
          "from": "human",
          "content": "Introduce butterflies species in Oregon."
        }
      ],
      "chosen": [
        {
          "from": "assistant",
          "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
        },
        ...
      ],
      "rejected": [
        {
          "from": "assistant",
          "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
        },
        ...
      ]
    },
    ...
]
```
#### Step 2: Preprocessing

Similar to the second step of the previous stage, we format the reward data into the same structured format as used in Step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.

#### Step 3: Training

You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. More details can be found in the [example guideline](./examples/README.md).
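For reference, a sketch of the pairwise ranking objective commonly used for reward model training, which presumably corresponds to the `log_sig` loss option that appears in the test scripts; treat the exact form as an assumption:

```latex
% Pairwise log-sigmoid ranking loss for a reward model r_\theta, given a prompt x,
% a chosen response y_c, and a rejected response y_r (assumed form):
\mathcal{L}_{\mathrm{RM}}(\theta) =
  -\,\mathbb{E}_{(x,\,y_c,\,y_r)\sim\mathcal{D}}
  \left[\log \sigma\bigl(r_\theta(x, y_c) - r_\theta(x, y_r)\bigr)\right]
```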
### RLHF Training Stage3 - Proximal Policy Optimization

In Stage3 we use a reinforcement learning algorithm, Proximal Policy Optimization (PPO), which is the most complex part of the training process:

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
</p>

#### Step 1: Data Collection

PPO uses two kinds of training data: the prompt data and the SFT data (optional). The prompt dataset is mandatory; each sample in it ends with a line from the "human", so the "assistant" needs to generate a response to that line. Note that you can still use conversations that end with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
[
  {"messages":
    [
      {
        "from": "human",
        "content": "what are some pranks with a pen i can do?"
      }
      ...
    ]
  },
  ...
]
```

#### Step 2: Data Preprocessing
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh).
#### Step 3: Training
You can run [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some arguments unique to PPO; please refer to the training configuration section for the other training configurations. More details can be found in the [example guideline](./examples/README.md).

```bash
--pretrain $PRETRAINED_MODEL_PATH \
--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architecture
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
--prompt_dataset ${prompt_dataset[@]} \ # list of strings, the prompt dataset
--ptx_dataset ${ptx_dataset[@]} \ # list of strings, the SFT data used in the SFT stage
--ptx_batch_size 1 \ # batch size for calculating the ptx loss
--ptx_coef 0.0 \ # non-zero if the ptx loss is enabled
--num_episodes 2000 \ # number of episodes to train
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size 8 \
--train_batch_size 4 \
--accumulation_steps 2
```
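For intuition on `ptx_coef`: it mixes a pretraining/SFT language-modeling loss into the actor objective, in the spirit of PPO-ptx. The following is only a schematic sketch under assumed names (`combined_actor_loss`, `ppo_loss` and `ptx_loss` are illustrative, not the trainer's actual identifiers):

```python
import torch

def combined_actor_loss(ppo_loss: torch.Tensor, ptx_loss: torch.Tensor, ptx_coef: float) -> torch.Tensor:
    # ptx_coef = 0.0 disables the pretraining term entirely
    return ppo_loss + ptx_coef * ptx_loss

# toy usage: scalar tensors standing in for the PPO policy loss and the
# causal-LM loss computed on a batch drawn from --ptx_dataset
loss = combined_actor_loss(torch.tensor(0.8), torch.tensor(2.3), ptx_coef=0.0)
print(loss)  # tensor(0.8000)
```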
Each episode has two phases, the collect phase and the update phase. During the collect phase, we collect experiences (answers generated by the actor) and store them in the ExperienceBuffer. The data in the ExperienceBuffer is then used during the update phase to update the parameters of the actor and critic. The experience buffer size is determined as follows (a worked example is given after the formulas):
- Without tensor parallelism,
```
experience buffer size
= num_process * num_collect_steps * experience_batch_size
= train_batch_size * accumulation_steps * num_process
```
- With tensor parallelism,
```
num_tp_group = num_process / tp
experience buffer size
= num_tp_group * num_collect_steps * experience_batch_size
= train_batch_size * accumulation_steps * num_tp_group
```
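As a quick sanity check of the relation above, here is a small worked example using the argument values from the sample command (8 processes and no tensor parallelism are assumptions for illustration):

```python
# experience buffer size must match what the update phase consumes:
# num_process * num_collect_steps * experience_batch_size
#   == train_batch_size * accumulation_steps * num_process
num_process = 8
num_collect_steps = 1
experience_batch_size = 8
train_batch_size = 4
accumulation_steps = 2

buffer_size = num_process * num_collect_steps * experience_batch_size  # 64
consumed = train_batch_size * accumulation_steps * num_process         # 64
assert buffer_size == consumed, "collect and update phases are out of sync"
print(buffer_size)
```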
## Alternative Option For RLHF: Direct Preference Optimization
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
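For intuition, DPO directly increases the log-likelihood margin of the chosen response over the rejected one relative to a frozen reference model, so no reward model or experience collection is needed. Below is a minimal sketch of the DPO loss that assumes per-response log-probabilities have already been computed; it follows the paper's formula and is not the ColossalChat implementation.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta: float = 0.1):
    """DPO loss from https://arxiv.org/abs/2305.18290.

    Each argument is a (batch,) tensor of summed log-probs of the response tokens.
    """
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# toy usage with random log-probabilities standing in for model outputs
logps = torch.randn(4), torch.randn(4), torch.randn(4), torch.randn(4)
print(dpo_loss(*logps))
```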
### DPO Training Stage1 - Supervised Instructs Tuning
Please refer to the [SFT section](#rlhf-training-stage1---supervised-instructs-tuning) in the PPO part.
### DPO Training Stage2 - DPO Training
#### Step 1: Data Collection & Preparation
For DPO training, you only need the preference dataset. Please follow the instructions in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
#### Step 2: Training
You can run [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).
### Inference Quantization and Serving - After Training
We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.
- Option 1: Save the model weights, model config, and generation config (note: the tokenizer will not be saved); the result can be loaded using HF's `from_pretrained` method.
```python
# if using LoRA, you can choose to merge the LoRA weights before saving
if args.lora_rank > 0 and args.merge_lora_weights:
    from coati.models.lora import LORA_MANAGER

    # NOTE: set the model to eval mode to merge the LoRA weights
    LORA_MANAGER.merge_weights = True
    model.eval()
# save the model checkpoint after fitting, on rank 0 only
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
```
- Option 2: Save the model weights, model config, generation config, as well as the optimizer, learning rate scheduler, and running states (note: the tokenizer will not be saved), which are needed for resuming training.
```python
from coati.utils import save_checkpoint
# save model checkpoint after fitting on only rank0
save_checkpoint(
save_dir=actor_save_dir,
booster=actor_booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
epoch=0,
step=step,
batch_size=train_batch_size,
coordinator=coordinator,
)
```
To load the saved checkpoint
```python
from coati.utils import load_checkpoint
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=checkpoint_path,
booster=booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
)
```
</details>

<details><summary><b>How to train with limited resources</b></summary>
Here are some suggestions that can help you train a 7B model on a single or multiple consumer-grade GPUs.

`batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters for successfully training the model. To maintain a decent effective batch size for gradient calculation, consider increasing `accumulation_steps` and reducing `batch_size` on each rank.

If you only have a single 24G GPU, using LoRA together with the `zero2_cpu` plugin will generally be sufficient.

`gemini` and `gemini_auto` can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. However, these plugins do not support gradient accumulation.

If you have multiple GPUs, each with very limited VRAM (say 8GB), you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs you have.
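As a rough illustration of how these options map to booster plugins, here is a minimal sketch assuming default settings; treat it as an example rather than the exact wiring used by the training scripts.

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin

def build_booster(choice: str = "zero2_cpu") -> Booster:
    """Pick a plugin matching the advice above (illustrative defaults only)."""
    if choice == "zero2_cpu":
        # ZeRO stage 2 with CPU offload: pairs well with LoRA on a single 24G GPU
        plugin = LowLevelZeroPlugin(stage=2, precision="bf16", cpu_offload=True)
    elif choice == "gemini":
        # Gemini can hold the whole model without LoRA, given enough CPU memory
        plugin = GeminiPlugin(precision="bf16", placement_policy="auto")
    else:
        raise ValueError(f"unknown plugin choice: {choice}")
    return Booster(plugin=plugin)

if __name__ == "__main__":
    # assumes the script is launched with `colossalai run` / torchrun so that
    # distributed state is initialized before the booster wraps the model
    colossalai.launch_from_torch({})
    booster = build_booster("zero2_cpu")
    # model, optimizer, dataloader, lr_scheduler would then be wrapped via booster.boost(...)
```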
</details>

## The Plan
- [x] support inference
- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
- [x] implement PPO-ptx fine-tuning
- [x] support flash-attention
- [x] implement DPO fine-tuning
- [ ] integrate with Ray
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)
Coati is developed by ColossalAI Team:
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO.
PhD students from the [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
{
"chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"human_line_start": [
2
],
"human_line_end": [
2
],
"assistant_line_start": [
2
],
"assistant_line_end": [
2
],
"end_of_system_line_position": 0
}
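The file above is the conversation template config (e.g., `Opt.json`) used for the OPT benchmark. Its `chat_template` field follows the Hugging Face Jinja chat-template convention, so it can be applied with a tokenizer directly. Below is a small illustrative sketch of how it would render a conversation; the file path and messages are assumptions for illustration.

```python
import json
from transformers import AutoTokenizer

# load the config shown above (path is an assumption for illustration)
with open("Opt.json", encoding="utf8") as f:
    config = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.chat_template = config["chat_template"]

messages = [
    {"role": "system", "content": config["system_message"]},
    {"role": "user", "content": "what are some pranks with a pen i can do?"},
    {"role": "assistant", "content": "Here are a few harmless pen pranks..."},
]
# renders: system text, then "Human: <s>...</s>", then "Assistant: <s>...</s>"
print(tokenizer.apply_chat_template(messages, tokenize=False))
```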
We also provide various training strategies:

- gemini: ColossalAI GeminiPlugin with `placement_policy="cuda"`, like zero3
- gemini_auto: ColossalAI GeminiPlugin with `placement_policy="cpu"`, like zero3-offload
- zero2: ColossalAI zero2
- zero2_cpu: ColossalAI zero2-offload
- 3d: ColossalAI HybridParallelPlugin with TP, DP support

## How to Run
```bash
cd ../tests
# Prepare data for benchmark
SFT_DATASET=/path/to/sft/data/ \
PROMPT_DATASET=/path/to/prompt/data/ \
PRETRAIN_DATASET=/path/to/ptx/data/ \
PREFERENCE_DATASET=/path/to/preference/data \
./test_data_preparation.sh
# Start benchmark
./benchmark_ppo.sh
```
Model=Opt-125m; lora_rank=0; plugin=zero2
Max CUDA memory usage: 26123.16 MB
Model=Opt-125m; lora_rank=0; plugin=zero2
Max CUDA memory usage: 26123.91 MB
facebook/opt-125m; 0; zero2
Performance summary:
Generate 768 samples, throughput: 188.48 samples/s, TFLOPS per GPU: 361.23
Train 768 samples, throughput: 448.38 samples/s, TFLOPS per GPU: 82.84
Overall throughput: 118.42 samples/s
Overall time per sample: 0.01 s
Make experience time per sample: 0.01 s, 62.83%
Learn time per sample: 0.00 s, 26.41%
facebook/opt-125m; 0; zero2
Performance summary:
Generate 768 samples, throughput: 26.32 samples/s, TFLOPS per GPU: 50.45
Train 768 samples, throughput: 71.15 samples/s, TFLOPS per GPU: 13.14
Overall throughput: 18.86 samples/s
Overall time per sample: 0.05 s
Make experience time per sample: 0.04 s, 71.66%
Learn time per sample: 0.01 s, 26.51%
"""
For becnhmarking ppo. Mudified from examples/training_scripts/train_ppo.py
"""
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
import torch.distributed as dist
from coati.dataset import (
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_conversation_template,
setup_distributed_dataloader,
)
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import PerformanceEvaluator
from coati.trainer.utils import is_rank_0
from coati.utils import load_checkpoint, replace_with_flash_attention
from transformers import AutoTokenizer, OPTForCausalLM
from transformers.models.opt.configuration_opt import OPTConfig
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.logging import get_dist_logger

# module-level logger used for tokenizer setup warnings below
logger = get_dist_logger()
def get_model_numel(model: torch.nn.Module, plugin: str, tp: int) -> int:
numel = sum(p.numel() for p in model.parameters())
if plugin == "3d" and tp > 1:
numel *= dist.get_world_size()
return numel
def get_gpt_config(model_name: str) -> OPTConfig:
model_map = {
"125m": OPTConfig.from_pretrained("facebook/opt-125m"),
"350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
"700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
"1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
"2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
"3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
"5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
"6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
"10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
"13b": OPTConfig.from_pretrained("facebook/opt-13b"),
}
try:
return model_map[model_name]
except KeyError:
raise ValueError(f'Unknown model "{model_name}"')
def benchmark_train(args):
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx = LazyInitContext(default_device=get_current_device()) if "gemini" in args.plugin else nullcontext()
booster_policy = None
with init_ctx:
actor = OPTForCausalLM(config=get_gpt_config(args.pretrain))
# Disable dropout
disable_dropout(actor)
ref_model = OPTForCausalLM(config=get_gpt_config(args.pretrain))
reward_model = RewardModel(config=get_gpt_config("350m"))
critic = Critic(config=get_gpt_config("350m"))
disable_dropout(critic)
actor_numel = get_model_numel(actor, args.plugin, args.tp)
critic_numel = get_model_numel(critic, args.plugin, args.tp)
initial_model_numel = get_model_numel(ref_model, args.plugin, args.tp)
reward_model_numel = get_model_numel(reward_model, args.plugin, args.tp)
performance_evaluator = PerformanceEvaluator(
actor_numel,
critic_numel,
initial_model_numel,
reward_model_numel,
enable_grad_checkpoint=False,
ignore_episodes=2,
train_config={"model": "facebook/opt-" + args.pretrain, "lora_rank": args.lora_rank, "plugin": args.plugin},
save_path="./benchmark_performance_summarization.txt",
)
if args.tp > 1:
if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
raise ValueError("Reward model and critic model must have the same architecture")
if reward_model.model.config.architectures[0] == "BloomForCausalLM":
from colossalai.shardformer.policies.bloom import BloomPolicy
booster_policy = BloomPolicy()
elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
from colossalai.shardformer.policies.llama import LlamaPolicy
booster_policy = LlamaPolicy()
elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
from colossalai.shardformer.policies.gpt2 import GPT2Policy
booster_policy = GPT2Policy()
elif reward_model.model.config.architectures[0] == "ChatGLMModel":
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
booster_policy = ChatGLMPolicy()
elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
from colossalai.shardformer.policies.opt import OPTPolicy
booster_policy = OPTPolicy()
else:
raise ValueError("Unknown model architecture for policy")
if args.lora_rank > 0:
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
actor.gradient_checkpointing_enable()
critic.model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
if args.use_flash_attn:
replace_with_flash_attention(model=actor)
replace_with_flash_attention(model=critic)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
if os.path.exists(args.conversation_template_config):
conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8"))
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
stop_token_ids = (
conversation_template.assistant_line_end if len(conversation_template.assistant_line_end) > 0 else None
)
else:
raise ValueError("Conversation template config is not provided or incorrect")
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers don't allow setting the pad_token manually, e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "left" # left padding for generation (online learning)
# configure generation config
actor.generation_config.update(
pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
)
# configure optimizer
coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
actor_optim = HybridAdam(
model_params=actor.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
coordinator.print_on_master(f"setting up optimizer for critic: lr={args.lr}, weight_decay={args.weight_decay}")
critic_optim = HybridAdam(
model_params=critic.parameters(),
lr=args.critic_lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
coordinator.print_on_master(f"prompt dataset size: {len(train_prompt_dataset)}")
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = setup_distributed_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
)
if len(args.pretrain_dataset) > 0:
train_pretrain_dataset = load_tokenized_dataset(
dataset_paths=args.pretrain_dataset, mode="train", mode_map=mode_map
)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = setup_distributed_dataloader(
dataset=train_pretrain_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
)
else:
train_pretrain_dataloader = None
if args.warmup_steps is None:
args.warmup_steps = int(0.025 * args.num_episodes)
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
actor_lr_scheduler = CosineAnnealingWarmupLR(
optimizer=actor_optim,
total_steps=args.num_episodes,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
critic_lr_scheduler = CosineAnnealingWarmupLR(
optimizer=critic_optim,
total_steps=args.num_episodes,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
precision=args.mixed_precision,
)
custom_plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
precision=args.mixed_precision,
custom_policy=booster_policy,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
if args.plugin != "3d":
custom_plugin = plugin
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
rm_booster = Booster(plugin=custom_plugin)
critic_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
dataloader=train_prompt_dataloader,
)
critic, critic_optim, _, _, critic_lr_scheduler = critic_booster.boost(
model=critic,
optimizer=critic_optim,
lr_scheduler=critic_lr_scheduler,
dataloader=train_prompt_dataloader,
)
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
sampler_start_idx = 0
start_step = 0
if args.rm_checkpoint_path is not None:
if "modeling" in args.rm_checkpoint_path:
rm_booster.load_model(reward_model, args.rm_checkpoint_path)
else:
_, _, _ = load_checkpoint(
load_dir=args.rm_checkpoint_path,
booster=rm_booster,
model=reward_model,
optimizer=None,
lr_scheduler=None,
)
coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
actor_booster.load_model(actor, args.checkpoint_path)
ref_booster.load_model(ref_model, args.checkpoint_path)
coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
else:
_, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=actor_booster,
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
)
_, _, _ = load_checkpoint(
load_dir=args.checkpoint_path,
booster=ref_booster,
model=ref_model,
optimizer=critic_optim,
lr_scheduler=critic_lr_scheduler,
)
assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
if args.critic_checkpoint_path is not None:
if "modeling" in args.critic_checkpoint_path:
critic_booster.load_model(critic, args.critic_checkpoint_path)
else:
_, _, _ = load_checkpoint(
load_dir=args.critic_checkpoint_path,
booster=critic_booster,
model=critic,
optimizer=critic_optim,
lr_scheduler=critic_lr_scheduler,
)
coordinator.print_on_master(f"Loaded critic checkpoint {args.critic_checkpoint_path}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
# configure trainer
trainer = PPOTrainer(
actor_booster,
critic_booster,
actor,
critic,
reward_model,
ref_model,
actor_optim,
critic_optim,
actor_lr_scheduler,
critic_lr_scheduler,
tokenizer=tokenizer,
stop_token_ids=stop_token_ids,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,
buffer_limit=args.num_collect_steps * args.experience_batch_size,
max_length=args.max_length,
max_new_tokens=args.max_seq_len,
use_cache=True,
do_sample=True,
temperature=0.7,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
top_k=50,
use_tp=args.tp > 1,
offload_inference_models="gemini" not in args.plugin,
callbacks=[performance_evaluator],
coordinator=coordinator,
)
trainer.fit(
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
prompt_dataloader=train_prompt_dataloader,
pretrain_dataloader=train_pretrain_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
actor.eval()
critic.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final actor model checkpoint")
actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
)
coordinator.print_on_master("Start saving final critic model checkpoint")
critic_booster.save_model(critic, os.path.join(trainer.critic_save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final critic model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
)
memory_consumption = torch.cuda.max_memory_allocated() / 1024**2
if is_rank_0():
with open("./benchmark_memory_consumption.txt", "a+") as f:
f.write(
f"Model=Opt-{args.pretrain}; lora_rank={args.lora_rank}; plugin={args.plugin}\nMax CUDA memory usage: {memory_consumption:.2f} MB\n"
)
coordinator.print_on_master(f"Max CUDA memory usage: {memory_consumption:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_dataset", nargs="+", default=[])
parser.add_argument("--pretrain_dataset", nargs="+", default=[])
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument(
"--conversation_template_config",
type=str,
default=None,
help="Path \
to save conversation template config files.",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--critic_checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--save_interval", type=int, default=1000)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--lora_train_bias", type=str, default="none")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.0)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--log_dir", default="logs", type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
benchmark_train(args)
#!/usr/bin/env bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
set -xu
NUM_RETRY=3
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
EXAMPLES_DIR=$BASE_DIR/examples
TEMP_DIR=$BASE_DIR/temp
MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config
# To benchmark different models, change the following line
# MODELS=('125m' '350m' '700m' '1.3b' '2.7b' '3.5b' '5.5b' '6.7b' '10b' '13b')
MODELS=('125m')
# To benchmark different strategies, change the following line
# PLUGINS=('zero2', 'zero2_cpu', '3d')
PLUGINS=('zero2')
LORA_RANK=('0')
export OMP_NUM_THREADS=8
rm ./benchmark_memory_consumption.txt
rm ./benchmark_performance_summarization.txt
# install requirements
pip install -r $EXAMPLES_DIR/requirements.txt
random_choice() {
local arr=("$@")
local len=${#arr[@]}
local idx=$((RANDOM % len))
echo ${arr[$idx]}
}
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
)
GRAD_CKPTS=('' '--grad_checkpoint')
GRAD_CKPTS=('')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
plugins=($(shuf -e "${PLUGINS[@]}"))
for plugin in ${plugins[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$model
tokenizer_dir="facebook/opt-125m"
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
if [[ $plugin == "3d" ]]; then
tp='4'
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 9); do
prompt_dataset+=("$TEMP_DIR/benchmark/arrow/part-$split")
done
colossalai run --nproc_per_node 8 --master_port 28547 $BASE_DIR/benchmarks/benchmark_ppo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_coef 0 \
--save_path $MODEL_SAVE_PATH \
--conversation_template_config ./Opt.json \
--lora_rank $lora_rank \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--max_seq_len 128 \
--max_length 512 \
--experience_batch_size 32 \
--train_batch_size 32 \
--accumulation_steps 1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--use_flash_attn \
--tp $tp \
--lr 2e-5 \
$grad_ckpt
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
SAVE_DIR=""
BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
EXAMPLES_DIR=$BASE_DIR/examples
SAVE_DIR=$BASE_DIR/temp/benchmark
rm -rf $SAVE_DIR
python $EXAMPLES_DIR/data_preparation_scripts/prepare_prompt_dataset.py --data_input_dirs "/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train" \
--conversation_template_config ./Opt.json \
--tokenizer_dir "facebook/opt-125m" \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--num_samples_per_datafile 30