Commit 5eaaba41 authored by Rayyyyy

First add in 0524
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html
import copy
import json
import torch
from torch.utils.data import Dataset
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
class InstructionDataset(Dataset):
def __init__(self, dataset_config, tokenizer, partition="train"):
self.ann = json.load(open(dataset_config.data_path))
if partition == "train":
self.ann = self.ann[200:]
else:
self.ann = self.ann[:200]
self.tokenizer = tokenizer
def __len__(self):
return len(self.ann)
def __getitem__(self, index):
IGNORE_INDEX = -100  # the default ignore_index in CrossEntropyLoss
ann = self.ann[index]
if ann.get("input", "") == "":
prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
else:
prompt = PROMPT_DICT["prompt_input"].format_map(ann)
example = prompt + ann["output"]
prompt = torch.tensor(
self.tokenizer.encode(prompt), dtype=torch.int64
)
example = self.tokenizer.encode(example)
example.append(self.tokenizer.eos_token_id)
example = torch.tensor(
example, dtype=torch.int64
)
labels = copy.deepcopy(example)
labels[: len(prompt)] = -1
example_mask = example.ge(0)
label_mask = labels.ge(0)
example[~example_mask] = 0
labels[~label_mask] = IGNORE_INDEX
return {
"input_ids": example.tolist(),
"labels": labels.tolist(),
"attention_mask":example_mask.tolist(),
}
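For reference, a minimal usage sketch of the dataset above; the config class, data path, and tokenizer name are illustrative stand-ins, not part of the repository, and the JSON file is assumed to contain Alpaca-style records with "instruction", optional "input", and "output" fields:

from dataclasses import dataclass
from transformers import AutoTokenizer

@dataclass
class AlpacaConfig:  # hypothetical stand-in for the alpaca_dataset config
    data_path: str = "alpaca_data.json"  # hypothetical path to an Alpaca-style JSON file

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed tokenizer
train_ds = InstructionDataset(AlpacaConfig(), tokenizer, partition="train")
sample = train_ds[0]
# prompt tokens are masked with -100 in "labels", so loss is computed only on the response
print(len(sample["input_ids"]), sample["labels"][:5])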
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://huggingface.co/datasets/jfleg
# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
from datasets import load_dataset
from pathlib import Path
from torch.utils.data import Dataset
class grammar(Dataset):
def __init__(
self,
tokenizer,
csv_name=None,
):
try:
self.dataset = load_dataset(
"csv",
data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"},
delimiter=",",
)
except Exception as e:
print("Loading of grammar dataset failed! Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset.")
raise e
# self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path)
# if num_samples:
# self.dataset = self.dataset.select(list(range(0, num_samples)))
self.tokenizer = tokenizer
self.print_text = False # print_text
def __len__(self):
return self.dataset["train"].shape[0]
def convert_to_features(self, example_batch):
# Create prompt and tokenize contexts and questions
if self.print_text:
print("Input Text: ", self.clean_text(example_batch["text"]))
input_ = example_batch["input"]
target_ = example_batch["target"]
prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
sample = {
"input_ids": prompt_ids + label_ids,
"attention_mask": [1] * len(prompt_ids + label_ids),
"labels": [-100] * len(prompt_ids) + label_ids
}
return sample
def __getitem__(self, index):
return self.convert_to_features(self.dataset["train"][int(index)])
def get_dataset(
dataset_config, tokenizer, csv_name=None
):
"""cover function for handling loading the working dataset"""
"""dataset loading"""
if csv_name is None:
currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
print(f"Loading dataset {currPath}")
csv_name = str(currPath)
dataset = grammar(
tokenizer=tokenizer,
csv_name=csv_name,
)
return dataset
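A minimal sketch of driving the loader above; the tokenizer name and csv path are assumptions, and the csv must provide the "input"/"target" columns produced by the preparation notebook below:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed tokenizer
# csv with "input"/"target" columns, e.g. produced by grammar_dataset_process.ipynb
train_ds = get_dataset(dataset_config=None, tokenizer=tokenizer, csv_name="datasets_grammar/grammar_train.csv")
sample = train_ds[0]
# the prompt is masked with -100 in "labels"; only the corrected sentence is supervised
print(len(sample["input_ids"]), len(sample["labels"]))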
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n",
"\n",
"Use this notebook to pull in datasets and apply pre-processing. Most grammar datasets unfortunately require preprocessing before being usable in training. (example - jfleg has 4 targets per input, so we have to rematch as 1:1 pairings) "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"from datasets import load_metric, load_dataset\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"list_replacements = [\n",
" (\" .\", \".\"), \n",
" (\" ,\", \",\"),\n",
" (\" '\", \"'\"),\n",
" (\" ?\", \"?\"),\n",
" (\" !\", \"!\"),\n",
" (\" :\", \":\"),\n",
" (\" ;\", \";\"),\n",
" (\" n't\", \"n't\"),\n",
" (\" v\", \"v\"),\n",
" (\"2 0 0 6\", \"2006\"),\n",
" (\"5 5\", \"55\"),\n",
" (\"4 0 0\", \"400\"),\n",
" (\"1 7-5 0\", \"1750\"),\n",
" (\"2 0 %\", \"20%\"),\n",
" (\"5 0\", \"50\"),\n",
" (\"1 2\", \"12\"),\n",
" (\"1 0\", \"10\"),\n",
" ('\" ballast water', '\"ballast water')\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def correct_spacing(item):\n",
" \"\"\" we iterate through the list of all replacements per each item in dataset\"\"\"\n",
" for fix in list_replacements:\n",
" item = item.replace(fix[0], fix[1])\n",
" return item\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def generate_csv(csv_path, dataset):\n",
" \"\"\" apply spacing corrections and save out matched pairs to csv file as dataset\"\"\"\n",
" with open(csv_path, 'w', newline='') as csvfile:\n",
" writer = csv.writer(csvfile)\n",
" writer.writerow([\"input\", \"target\"])\n",
" for case in dataset:\n",
" \t # Adding the t5 task indication prefix to input \n",
" input_text = case[\"sentence\"]\n",
" input_text = correct_spacing(input_text)\n",
"\n",
" for correction in case[\"corrections\"]:\n",
" correction = correct_spacing(correction)\n",
" # a few of the cases contain blank strings. \n",
" if input_text and correction:\n",
" writer.writerow([input_text, correction])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In Jfleg - validation will be used as 'train', test will be 'validation'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n",
"Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n"
]
}
],
"source": [
"train_dataset = load_dataset(\"jfleg\", split='validation[:]') \n",
"eval_dataset = load_dataset(\"jfleg\", split='test[:]')\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['sentence', 'corrections'],\n",
" num_rows: 755\n",
"})\n",
"Dataset({\n",
" features: ['sentence', 'corrections'],\n",
" num_rows: 748\n",
"})\n"
]
}
],
"source": [
"print(train_dataset)\n",
"print(eval_dataset)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas . \n",
"['Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ']\n"
]
}
],
"source": [
"print(train_dataset['sentence'][22])\n",
"print(train_dataset['corrections'][22])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas. '"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clean22 = correct_spacing(train_dataset['sentence'][22])\n",
"clean22"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"jfleg_dir = Path.cwd()/'jfleg_dataset' # if you only use 'jfleg', hf will try and use that and complain\n",
"jfleg_dir.mkdir(parents=True,exist_ok=True)\n",
"c4_dir = Path.cwd()/'c4_dataset'\n",
"c4_dir.mkdir(parents=True,exist_ok=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Process Jfleg data "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"j_train_file = jfleg_dir/'jtrain.csv'\n",
"j_eval_file = jfleg_dir/'jeval.csv'"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"generate_csv(j_train_file, train_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"generate_csv(j_eval_file, eval_dataset)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Process C4_200M (!) - we'll pull 10K to start"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"c4_dataset = load_dataset(\"liweili/c4_200m\", streaming = True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"iterator = iter(c4_dataset['train'])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def c4_generate_csv(csv_path, iterator, num_examples):\n",
" with open(csv_path, 'w', newline='') as csvfile:\n",
" writer = csv.writer(csvfile)\n",
" writer.writerow([\"input\", \"target\"])\n",
" for i in range(0,num_examples):\n",
" data = next(iterator)\n",
" input_text = data[\"input\"]\n",
" input_text = correct_spacing(input_text)\n",
" correction = correct_spacing(data[\"output\"])\n",
" if input_text and correction:\n",
" writer.writerow([input_text, correction])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"c4_dir = Path.cwd()/'c4_dataset'\n",
"c4_dir.mkdir(parents=True,exist_ok=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"You can modify the following to make the csv file with desired number of instances, here we go for 10k to make a quick test"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"c4_filename = c4_dir/'c4train_10k.csv'"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"c4_generate_csv(c4_filename, iterator, num_examples=10000)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a single training file by combining jtrain and c4train"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"merge_list = [j_train_file, c4_filename, ]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"combined_csv = pd.concat([pd.read_csv(fn) for fn in merge_list])\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"merged_name = \"gtrain_10k.csv\""
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"combined_csv.to_csv(merged_name, index=False, encoding = 'utf-8-sig', )"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"eval_name = \"grammar_validation.csv\""
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"eval_csv = pd.read_csv(j_eval_file)\n",
"eval_csv.to_csv(eval_name, index=False, encoding = 'utf-8-sig', )"
]
}
],
"metadata": {
"interpreter": {
"hash": "5b2c14c5f2a3b21e6c2412c8196f5145870350e81c0b737cae3e5c60eb1e1eac"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://huggingface.co/datasets/samsum
import copy
import datasets
def get_preprocessed_samsum(dataset_config, tokenizer, split):
dataset = datasets.load_dataset("samsum", split=split)
prompt = (
f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
)
def apply_prompt_template(sample):
return {
"prompt": prompt.format(dialog=sample["dialogue"]),
"summary": sample["summary"],
}
dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
def tokenize_add_label(sample):
prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
sample = {
"input_ids": prompt + summary,
"attention_mask" : [1] * (len(prompt) + len(summary)),
"labels": [-100] * len(prompt) + summary,
}
return sample
dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
return dataset
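A short sketch of calling the samsum helper above, assuming a Llama tokenizer and network access to download the dataset:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed tokenizer
train_ds = get_preprocessed_samsum(dataset_config=None, tokenizer=tokenizer, split="train")
# prompt tokens carry the -100 label; only the summary contributes to the loss
print(train_ds[0].keys())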
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import dataclasses
import fire
import random
import torch
import torch.optim as optim
from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy
)
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
AutoTokenizer,
LlamaForCausalLM,
LlamaConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
update_config,
generate_peft_config,
generate_dataset_config,
get_dataloader_kwargs,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
train,
freeze_transformer_layers,
setup,
setup_environ_flags,
clear_gpu_cache,
print_model_size,
get_policies,
)
from accelerate.utils import is_xpu_available
def setup_wandb(train_config, fsdp_config, **kwargs):
try:
import wandb
except ImportError:
raise ImportError(
"You are trying to use wandb which is not currently installed. "
"Please install it using pip install wandb"
)
from llama_recipes.configs import wandb_config as WANDB_CONFIG
wandb_config = WANDB_CONFIG()
update_config(wandb_config, **kwargs)
init_dict = dataclasses.asdict(wandb_config)
run = wandb.init(**init_dict)
run.config.update(train_config)
run.config.update(fsdp_config, allow_val_change=True)
return run
def main(**kwargs):
# Update the configuration for the training and sharding process
train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
update_config((train_config, fsdp_config), **kwargs)
# Set the seeds for reproducibility
if is_xpu_available():
torch.xpu.manual_seed(train_config.seed)
torch.manual_seed(train_config.seed)
random.seed(train_config.seed)
if train_config.enable_fsdp:
setup()
# torchrun specific
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.distributed.is_initialized():
if is_xpu_available():
torch.xpu.set_device(local_rank)
elif torch.cuda.is_available():
torch.cuda.set_device(local_rank)
clear_gpu_cache(local_rank)
setup_environ_flags(rank)
wandb_run = None
if train_config.use_wandb:
if not train_config.enable_fsdp or rank==0:
wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
# Load the pre-trained model and setup its configuration
use_cache = False if train_config.enable_fsdp else None
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
"""
for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
this avoids cpu oom when loading large models like llama 70B, in which case
model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
overhead and currently requires latest nightly.
"""
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
)
else:
llama_config = LlamaConfig.from_pretrained(train_config.model_name)
llama_config.use_cache = use_cache
with torch.device("meta"):
model = LlamaForCausalLM(llama_config)
else:
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
)
# Load the tokenizer and add special tokens
tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
model.resize_token_embeddings(len(tokenizer))
print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
# Prepare the model for int8 training if quantization is enabled
if train_config.quantization:
model = prepare_model_for_kbit_training(model)
# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if train_config.enable_fsdp and fsdp_config.pure_bf16:
model.to(torch.bfloat16)
if train_config.use_peft:
# Load the pre-trained peft model checkpoint and setup its configuration
if train_config.from_peft_checkpoint:
model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
peft_config = model.peft_config  # peft_config is a dict attribute on PeftModel, not a callable
# Generate the peft config and start fine-tuning from original model
else:
peft_config = generate_peft_config(train_config, kwargs)
model = get_peft_model(model, peft_config)
if wandb_run:
wandb_run.config.update(peft_config)
model.print_trainable_parameters()
hsdp_device_mesh_plan = None
if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)  # renamed so the imported hsdp_device_mesh helper is not shadowed
print("HSDP device mesh is ready")
#setting up FSDP if enable_fsdp is enabled
if train_config.enable_fsdp:
if not train_config.use_peft and train_config.freeze_layers:
freeze_transformer_layers(model, train_config.num_freeze_layers)
mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
device_id = 0
if is_xpu_available():
device_id = torch.xpu.current_device()
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()
model = FSDP(
model,
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
device_mesh=hsdp_device_mesh_plan,
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
if is_xpu_available():
model.to("xpu:0")
elif torch.cuda.is_available():
model.to("cuda")
dataset_config = generate_dataset_config(train_config, kwargs)
# Load and preprocess the dataset for training and validation
dataset_train = get_preprocessed_dataset(
tokenizer,
dataset_config,
split="train",
)
if not train_config.enable_fsdp or rank == 0:
print(f"--> Training Set Length = {len(dataset_train)}")
dataset_val = get_preprocessed_dataset(
tokenizer,
dataset_config,
split="test",
)
if not train_config.enable_fsdp or rank == 0:
print(f"--> Validation Set Length = {len(dataset_val)}")
if train_config.batching_strategy == "packing":
dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
# Create DataLoaders for the training and validation dataset
train_dataloader = torch.utils.data.DataLoader(
dataset_train,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**train_dl_kwargs,
)
eval_dataloader = None
if train_config.run_validation:
if train_config.batching_strategy == "packing":
dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
num_workers=train_config.num_workers_dataloader,
pin_memory=True,
**val_dl_kwargs,
)
# Initialize the optimizer and learning rate scheduler
if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
optimizer = AnyPrecisionAdamW(
model.parameters(),
lr=train_config.lr,
momentum_dtype=torch.bfloat16,
variance_dtype=torch.bfloat16,
use_kahan_summation=False,
weight_decay=train_config.weight_decay,
)
else:
optimizer = optim.AdamW(
model.parameters(),
lr=train_config.lr,
weight_decay=train_config.weight_decay,
)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
# Start the training process
results = train(
model,
train_dataloader,
eval_dataloader,
tokenizer,
optimizer,
scheduler,
train_config.gradient_accumulation_steps,
train_config,
fsdp_config if train_config.enable_fsdp else None,
local_rank if train_config.enable_fsdp else None,
rank if train_config.enable_fsdp else None,
wandb_run,
)
if not train_config.enable_fsdp or rank==0:
[print(f'Key: {k}, Value: {v}') for k, v in results.items()]
if train_config.use_wandb:
for k,v in results.items():
wandb_run.summary[k] = v
if __name__ == "__main__":
fire.Fire(main)
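For reference, a hedged sketch of how this fire entry point can be driven; the model name, dataset, and output directory are illustrative assumptions, and any other train_config/fsdp_config field can be overridden the same way (for multi-GPU FSDP runs the script is normally launched via torchrun instead):

# Roughly equivalent to: python <this script>.py --use_peft --peft_method lora ...
main(
    model_name="meta-llama/Llama-2-7b-hf",  # assumed HF checkpoint
    use_peft=True,
    peft_method="lora",
    dataset="samsum_dataset",
    batch_size_training=2,
    num_epochs=1,
    output_dir="peft_output",  # hypothetical output directory
)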
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
def read_dialogs_from_file(file_path):
with open(file_path, 'r') as file:
dialogs = json.load(file)
return dialogs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import fire
import os
import sys
import yaml
from transformers import AutoTokenizer
from llama_recipes.inference.model_utils import load_llama_from_config
# Get the current file's directory
current_directory = os.path.dirname(os.path.abspath(__file__))
# Get the parent directory
parent_directory = os.path.dirname(current_directory)
# Append the parent directory to sys.path
sys.path.append(parent_directory)
from model_checkpointing import load_sharded_model_single_gpu
def main(
fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints
consolidated_model_path="", # Path to save the HF converted model checkpoints
HF_model_path_or_name="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf)
):
try:
file_name = 'train_params.yaml'
# Combine the directory and file name to create the full path
train_params_path = os.path.join(fsdp_checkpoint_path, file_name)
# Open the file
with open(train_params_path, 'r') as file:
# Load the YAML data
data = yaml.safe_load(file)
# Access the 'model_name' field
HF_model_path_or_name = data.get('model_name')
print(f"Model name: {HF_model_path_or_name}")
except FileNotFoundError:
print(f"The file {train_params_path} does not exist.")
HF_model_path_or_name = input("Please enter the model name: ")
print(f"Model name: {HF_model_path_or_name}")
except Exception as e:
print(f"An error occurred: {e}")
#load the HF model definition from config
model_def = load_llama_from_config(HF_model_path_or_name)
print("model is loaded from config")
#load the FSDP sharded checkpoints into the model
model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path)
print("model is loaded from FSDP checkpoints")
#load the tokenizer from the model path
tokenizer = AutoTokenizer.from_pretrained(HF_model_path_or_name)
tokenizer.save_pretrained(consolidated_model_path)
#save the FSDP sharded checkpoints in HF format
model.save_pretrained(consolidated_model_path)
print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}")
if __name__ == "__main__":
fire.Fire(main)
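A hedged sketch of invoking the converter above; all paths are placeholders for your own checkpoint and output locations:

# Equivalent CLI (script name illustrative):
# python <converter script>.py --fsdp_checkpoint_path ... --consolidated_model_path ... --HF_model_path_or_name ...
main(
    fsdp_checkpoint_path="model_checkpoints/fine-tuned-llama/",  # hypothetical sharded FSDP checkpoint dir
    consolidated_model_path="converted_hf_model/",               # hypothetical output dir
    HF_model_path_or_name="meta-llama/Llama-2-7b-chat-hf",
)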
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import logging
import time
from abc import ABC, abstractmethod
from typing import Callable
import openai
from typing_extensions import override
NUM_LLM_RETRIES = 10
MAX_TOKENS = 1000
TEMPERATURE = 0.1
TOP_P = 0.9
LOG: logging.Logger = logging.getLogger(__name__)
class LLM(ABC):
def __init__(self, model: str, api_key: str | None = None) -> None:
if model not in self.valid_models():
LOG.warning(
f"{model} is not in the valid model list for {type(self).__name__}. Valid models are: {', '.join(self.valid_models())}."
)
self.model: str = model
self.api_key: str | None = api_key
@abstractmethod
def query(self, prompt: str) -> str:
"""
Abstract method to query an LLM with a given prompt and return the response.
Args:
prompt (str): The prompt to send to the LLM.
Returns:
str: The response from the LLM.
"""
pass
def query_with_system_prompt(self, system_prompt: str, prompt: str) -> str:
"""
Query an LLM with a system prompt and a user prompt and return the response.
Args:
system_prompt (str): The system prompt to send to the LLM.
prompt (str): The prompt to send to the LLM.
Returns:
str: The response from the LLM.
"""
return self.query(system_prompt + "\n" + prompt)
def _query_with_retries(
self,
func: Callable[..., str],
*args: str,
retries: int = NUM_LLM_RETRIES,
backoff_factor: float = 0.5,
) -> str:
last_exception = None
for retry in range(retries):
try:
return func(*args)
except Exception as exception:
last_exception = exception
sleep_time = backoff_factor * (2**retry)
time.sleep(sleep_time)
LOG.debug(
f"LLM Query failed with error: {exception}. Sleeping for {sleep_time} seconds..."
)
raise RuntimeError(
f"Unable to query LLM after {retries} retries: {last_exception}"
)
def query_with_retries(self, prompt: str) -> str:
return self._query_with_retries(self.query, prompt)
def query_with_system_prompt_with_retries(
self, system_prompt: str, prompt: str
) -> str:
return self._query_with_retries(
self.query_with_system_prompt, system_prompt, prompt
)
def valid_models(self) -> list[str]:
"""List of valid model parameters, e.g. 'gpt-3.5-turbo' for GPT"""
return []
class OPENAI(LLM):
"""Accessing OPENAI"""
def __init__(self, model: str, api_key: str) -> None:
super().__init__(model, api_key)
self.client = openai.OpenAI(api_key=api_key) # noqa
@override
def query(self, prompt: str) -> str:
# Best-effort attempt to suppress openai log spew.
# Likely won't work well in a multi-threaded environment.
level = logging.getLogger().level
logging.getLogger().setLevel(logging.WARNING)
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "user", "content": prompt},
],
max_tokens=MAX_TOKENS,
)
logging.getLogger().setLevel(level)
return response.choices[0].message.content
@override
def valid_models(self) -> list[str]:
return ["gpt-3.5-turbo", "gpt-4"]
class ANYSCALE(LLM):
"""Accessing ANYSCALE"""
def __init__(self, model: str, api_key: str) -> None:
super().__init__(model, api_key)
self.client = openai.OpenAI(base_url="https://api.endpoints.anyscale.com/v1", api_key=api_key) # noqa
@override
def query(self, prompt: str) -> str:
# Best-effort attempt to suppress openai log spew.
# Likely won't work well in a multi-threaded environment.
level = logging.getLogger().level
logging.getLogger().setLevel(logging.WARNING)
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "user", "content": prompt},
],
max_tokens=MAX_TOKENS,
)
logging.getLogger().setLevel(level)
return response.choices[0].message.content
@override
def valid_models(self) -> list[str]:
return [
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-70b-chat-hf",
"codellama/CodeLlama-34b-Instruct-hf",
"mistralai/Mistral-7B-Instruct-v0.1",
"HuggingFaceH4/zephyr-7b-beta",
]
class OctoAI(LLM):
"""Accessing OctoAI"""
def __init__(self, model: str, api_key: str) -> None:
super().__init__(model, api_key)
self.client = openai.OpenAI(base_url="https://text.octoai.run/v1", api_key=api_key) # noqa
@override
def query(self, prompt: str) -> str:
# Best-effort attempt to suppress openai log spew.
# Likely won't work well in a multi-threaded environment.
level = logging.getLogger().level
logging.getLogger().setLevel(logging.WARNING)
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful assistant. Keep your responses limited to one short paragraph if possible."},
{"role": "user", "content": prompt},
],
max_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
)
logging.getLogger().setLevel(level)
return response.choices[0].message.content
@override
def valid_models(self) -> list[str]:
return [
"llamaguard-2-8b",
"meta-llama-3-8b-instruct",
"meta-llama-3-70b-instruct",
]
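A small usage sketch of the wrappers above; how the API key is supplied is up to the caller:

import os

llm = OPENAI("gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"])  # assumed env var
# query_with_retries wraps query() with exponential backoff (up to NUM_LLM_RETRIES attempts)
print(llm.query_with_retries("Summarize the Llama 2 Community License in one sentence."))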
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
# Function to load the main model for text generation
def load_model(model_name, quantization, use_fast_kernels):
print(f"use_fast_kernels{use_fast_kernels}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
return_dict=True,
load_in_8bit=quantization,
device_map="auto",
low_cpu_mem_usage=True,
attn_implementation="sdpa" if use_fast_kernels else None,
)
return model
# Function to load the PeftModel with the trained PEFT adapter weights
def load_peft_model(model, peft_model):
peft_model = PeftModel.from_pretrained(model, peft_model)
return peft_model
# Load the model from its config so that FSDP checkpoints can be loaded into it
def load_llama_from_config(config_path):
model_config = LlamaConfig.from_pretrained(config_path)
model = LlamaForCausalLM(config=model_config)
return model
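A minimal inference sketch combining the helpers above; the base model name and adapter path are assumptions:

from transformers import AutoTokenizer

base = "meta-llama/Llama-2-7b-hf"  # assumed base model
model = load_model(base, quantization=False, use_fast_kernels=True)
model = load_peft_model(model, "path/to/peft_adapter")  # hypothetical adapter directory
tokenizer = AutoTokenizer.from_pretrained(base)
inputs = tokenizer("Summarize this dialog:\n...", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))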
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from dataclasses import dataclass
from string import Template
from typing import List
from enum import Enum
class LlamaGuardVersion(Enum):
LLAMA_GUARD_1 = "Llama Guard 1"
LLAMA_GUARD_2 = "Llama Guard 2"
class AgentType(Enum):
AGENT = "Agent"
USER = "User"
@dataclass
class SafetyCategory:
name: str
description: str
@dataclass
class ConversationTurn:
message: str
agent_type: AgentType
LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX = "O"
LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX = "S"
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""
PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""
PROMPT_INSTRUCTIONS_1 = """
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_INSTRUCTIONS_2 = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE_1 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_1} [/INST]")
PROMPT_TEMPLATE_2 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_2} [/INST]")
LLAMA_GUARD_1_CATEGORY = [
SafetyCategory(
"Violence and Hate.",
"""Should not
- Help people plan or engage in violence.
- Encourage people to engage in violence.
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
Can
- Provide information on violence and discrimination.
- Discuss the topics of hate and violence, and can discuss historical events involving violence."""),
SafetyCategory(
"Sexual Content.",
"""Should not
- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
Can
- Discuss topics related to sexuality and sexual education.
- Provide general information regarding sexual activities.
- Provide guidance on topics related to sexual health."""),
SafetyCategory(
"Criminal Planning.",
"""Should not
- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in criminal activities or condone criminal behavior.
Can
- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems."""),
SafetyCategory(
"Guns and Illegal Weapons.",
"""Should not
- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
Can
- Discuss firearms and the arguments for and against firearm ownership."""),
SafetyCategory(
"Regulated or Controlled Substances.",
"""Should not
- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
- Assist or encourage people to create such substances when it is illegal to do so.
Can
- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation."""),
SafetyCategory(
"Self-Harm.",
"""Should not
- Encourage people to harm themselves.
- Romanticize or condone self-harm or suicide.
- Provide information regarding the methods of suicide or self-harm.
- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
Can
- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
Should
- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""),
]
LLAMA_GUARD_2_CATEGORY = [
SafetyCategory(
"Violent Crimes. ",
"""AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to:
- Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping)
- Unlawful violence toward animals (ex: animal abuse)"""),
SafetyCategory(
"Non-Violent Crimes. ",
"""AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to:
- Financial crimes (ex: fraud, scams, money laundering)
- Property crimes (ex: burglary, robbery, arson, vandalism)
- Drug crimes (ex: creating or trafficking narcotics)
- Weapons crimes (ex: producing unlicensed firearms)
- Cyber crimes (ex: hacking, spyware, malware)"""),
SafetyCategory(
"Sex Crimes. ",
"""AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to:
- Human trafficking
- Sexual assault (ex: rape)
- Sexual harassment (ex: groping)
- Lewd conduct"""),
SafetyCategory(
"Child Exploitation. ",
"AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children."),
SafetyCategory(
"Specialized Advice. ",
"AI models should not create content that contains specialized financial, medical, or legal advice."),
SafetyCategory(
"Privacy. ",
"AI models should not create content that contains sensitive, personal information about private individuals."),
SafetyCategory(
"Intellectual Property. ",
"AI models should not create content that violates the intellectual property rights of third parties."),
SafetyCategory(
"Indiscriminate Weapons. ",
"""AI models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to:
- Chemical weapons (ex: nerve gas)
- Biological weapons (ex: anthrax)
- Radiological weapons (ex: salted bombs)
- Nuclear weapons (ex: atomic warheads)
- High-yield explosive weapons (ex: cluster munitions)"""),
SafetyCategory(
"Hate. ",
"AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes."),
SafetyCategory(
"Self-Harm. ",
"""AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to:
- Suicide
- Self-injury (ex: cutting)
- Disordered eating"""),
SafetyCategory(
"Sexual Content. ",
"AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts."),
]
def create_conversation(messages: List[str]) -> List[ConversationTurn]:
conversations = []
for i, message in enumerate(messages):
conversations.append(ConversationTurn(message=message, agent_type=AgentType.USER if i % 2 == 0 else AgentType.AGENT))
return conversations
def build_default_prompt(
agent_type: AgentType,
conversations: List[ConversationTurn],
llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_2):
if llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_2:
categories = LLAMA_GUARD_2_CATEGORY
category_short_name_prefix = LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX
prompt_template = PROMPT_TEMPLATE_2
else:
categories = LLAMA_GUARD_1_CATEGORY
category_short_name_prefix = LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX
prompt_template = PROMPT_TEMPLATE_1
return build_custom_prompt(
agent_type,
conversations,
categories,
category_short_name_prefix,
prompt_template)
def build_custom_prompt(
agent_type: AgentType,
conversations: List[ConversationTurn],
categories: List[SafetyCategory],
category_short_name_prefix: str,
prompt_template: str,
with_policy: bool = False):
categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}" + (f"\n{c.description}" if with_policy else "") for i, c in enumerate(categories)])
conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations])
return prompt_template.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
def build_prompt_test():
print(build_default_prompt(AgentType.AGENT,
[
ConversationTurn("Whats the color of the sky?", AgentType.USER),
ConversationTurn("The sky is blue.", AgentType.AGENT)
]))
print("\n\n")
# use a customized safety category and create_conversation function.
print(build_custom_prompt(
AgentType.AGENT,
create_conversation(
[
"<User Prompt placeholder>",
"<Agent Prompt placeholder>"
]),
[
SafetyCategory("Violence and Hate.","""Should not
- Help people plan or engage in violence.
- Encourage people to engage in violence.
- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
Can
- Provide information on violence and discrimination.
- Discuss the topics of hate and violence, and can discuss historical events involving violence.""",
),],
LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX,
PROMPT_TEMPLATE_2,
True
)
)
if __name__ == "__main__":
build_prompt_test()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import torch
import warnings
from typing import List
from string import Template
from enum import Enum
class AgentType(Enum):
AGENT = "Agent"
USER = "User"
# Class for performing safety checks using AuditNLG library
class AuditNLGSensitiveTopics(object):
def __init__(self):
pass
def __call__(self, output_text, **kwargs):
try:
from auditnlg.safety.exam import safety_scores
except ImportError as e:
print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt")
raise e
data = [{"output": output_text}]
result = safety_scores(data=data, method="sensitive_topics")
scores = result[1]["all_scores"][0]
is_safe = scores["pred_class"] == "none"
report = ""
if not is_safe:
report += f"Predicted class: {scores['pred_class']}\n"
report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n"
report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n"
return "Sensitive Topics", is_safe, report
class SalesforceSafetyChecker(object):
def __init__(self):
pass
def __call__(self, output_text, **kwargs):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config)
model.eval()
prefix = "Is the <Text> field safe or unsafe?"
input_ids = tokenizer(prefix + " <Text> " + output_text + " <Context> ", return_tensors="pt").input_ids
if len(input_ids[0]) > 512:
warnings.warn(
"Input length is > 512 token. Safety check result could be incorrect."
)
with torch.no_grad():
outputs = model.generate(
input_ids,
output_scores=True,
return_dict_in_generate=True,
max_new_tokens=20,
)
is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe"
report = ""
if not is_safe:
true_false_ids = tokenizer("true false").input_ids[:2]
keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"]
scores = {}
for k, i in zip(keys, range(3,20,2)):
scores[k] = round(outputs.scores[i][0,true_false_ids].softmax(dim=0)[0].item(), 5)
report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
return "Salesforce Content Safety Flan T5 Base", is_safe, report
def get_total_length(self, data):
prefix = "Is the <Text> field safe or unsafe "
input_sample = "<Text> {output} <Context> ".format(**data[0])
return len(self.tokenizer(prefix + input_sample)["input_ids"])
# Class for performing safety checks using Azure Content Safety service
class AzureSafetyChecker(object):
def __init__(self):
try:
from azure.ai.contentsafety import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
key = os.environ["CONTENT_SAFETY_KEY"]
endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"]
except ImportError:
raise Exception(
"Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety"
)
except KeyError:
raise Exception(
"Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT."
)
self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
def __call__(self, output_text, **kwargs):
from azure.core.exceptions import HttpResponseError
from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory
print(len(output_text))
if len(output_text) > 1000:
raise Exception("Input length to safety check is too long (>1000).")
categories = [
TextCategory.VIOLENCE,
TextCategory.SELF_HARM,
TextCategory.SEXUAL,
TextCategory.HATE,
]
request = AnalyzeTextOptions(text=output_text, categories=categories)
try:
response = self.client.analyze_text(request)
except HttpResponseError as e:
print("Analyze text failed.")
if e.error:
print(f"Error code: {e.error.code}")
print(f"Error message: {e.error.message}")
raise
print(e)
raise e
levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"}
severities = [
getattr(response, c.name.lower() + "_result").severity for c in categories
]
DEFAULT_LEVELS = [0, 0, 0, 0]
is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)])
report = ""
if not is_safe:
report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n"
report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n"
return "Azure Content Saftey API", is_safe, report
class LlamaGuardSafetyChecker(object):
def __init__(self):
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
model_id = "meta-llama/LlamaGuard-7b"
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
def __call__(self, output_text, **kwargs):
agent_type = kwargs.get('agent_type', AgentType.USER)
user_prompt = kwargs.get('user_prompt', "")
model_prompt = output_text.strip()
if(agent_type == AgentType.AGENT):
if user_prompt == "":
print("empty user prompt for agent check, returning unsafe")
return "Llama Guard", False, "Missing user_prompt from Agent response check"
else:
model_prompt = model_prompt.replace(user_prompt, "")
user_prompt = f"User: {user_prompt}"
agent_prompt = f"Agent: {model_prompt}"
chat = [
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": agent_prompt},
]
else:
chat = [
{"role": "user", "content": model_prompt},
]
input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to("cuda")
prompt_len = input_ids.shape[-1]
output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
split_result = result.split("\n")[0]
is_safe = split_result == "safe"
report = result
return "Llama Guard", is_safe, report
# Function to determine which safety checker to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics,
enable_salesforce_content_safety,
enable_llamaguard_content_safety):
safety_checker = []
if enable_azure_content_safety:
safety_checker.append(AzureSafetyChecker())
if enable_sensitive_topics:
safety_checker.append(AuditNLGSensitiveTopics())
if enable_salesforce_content_safety:
safety_checker.append(SalesforceSafetyChecker())
if enable_llamaguard_content_safety:
safety_checker.append(LlamaGuardSafetyChecker())
return safety_checker
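A short sketch of running the selected checkers over a candidate response; each checker returns a (method, is_safe, report) tuple:

checkers = get_safety_checker(
    enable_azure_content_safety=False,
    enable_sensitive_topics=False,
    enable_salesforce_content_safety=True,
    enable_llamaguard_content_safety=False,
)
candidate = "Model response to be screened."  # placeholder text
for checker in checkers:
    method, is_safe, report = checker(candidate, agent_type=AgentType.USER, user_prompt="")
    print(f"{method}: {'safe' if is_safe else 'unsafe'}\n{report}")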
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from llama_recipes.model_checkpointing.checkpoint_handler import (
load_model_checkpoint,
save_model_checkpoint,
load_optimizer_checkpoint,
save_optimizer_checkpoint,
save_model_and_optimizer_sharded,
load_model_sharded,
load_sharded_model_single_gpu
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from pathlib import Path
from datetime import datetime
import torch
import time
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig, # general model non-sharded, non-flattened params
LocalStateDictConfig, # flattened params, usable only by FSDP
# ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)
from torch.distributed._shard.checkpoint import (
FileSystemReader,
FileSystemWriter,
save_state_dict,
load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
DefaultSavePlanner,
DefaultLoadPlanner,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed._shard.checkpoint as dist_cp
import torch.distributed as dist
def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07-08:31:12_PM
"""
date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
print(f"--> current date and time of run = {date_of_run}")
return date_of_run
# create singleton saving policies to avoid making over and over
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
def load_model_sharded(model, rank, cfg):
# torch.manual_seed(103)
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
load_dir = Path.cwd() / folder_name
if not load_dir.exists():
if rank == 0:
print(f"No sharded_state_dict checkpoint directory found...skipping")
return
if rank == 0:
print(f"loading model from model path: {load_dir} ")
reader = FileSystemReader(load_dir)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
checkpoint = {"model": model.state_dict()}
if rank == 0:
ck = checkpoint.keys()
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
dist_cp.load_state_dict(
state_dict=checkpoint,
storage_reader=reader,
)
if rank == 0:
print(f"checkpoint after load_state_dict()")
ck = checkpoint.keys()
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
model.load_state_dict(checkpoint["model"])
if rank == 0:
print(f"Sharded state checkpoint loaded from {load_dir}")
def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
"""save model and optimizer via sharded_state_dict to save_dir"""
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
save_dir = Path.cwd() / folder_name
if rank == 0:
print(f"Saving model to {save_dir}")
distributed_writer = dist_cp.FileSystemWriter(
save_dir,
)
t0 = time.perf_counter()
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
state_dict = {"model": model.state_dict()}
if optim is not None:
state_dict["optim"] = FSDP.optim_state_dict(model, optim)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=distributed_writer,
planner=DefaultSavePlanner(),
)
dist.barrier()
t1 = time.perf_counter()
if rank == 0:
print(f"Sharded state checkpoint saved to {save_dir}")
print(
f"Checkpoint Time = {t1-t0:.4f}\n"
)
def save_model_checkpoint(
model,
optimizer,
rank,
cfg,
epoch=1,
):
"""saving model via rank0 cpu streaming and full_state_dict"""
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
):
cpu_state = model.state_dict()
print(f"saving process: rank {rank} done w model state_dict\n")
if rank == 0:
print(f"--> saving model ...")
# create save path
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
save_dir = Path.cwd() / folder_name
save_dir.mkdir(parents=True, exist_ok=True)
save_name = cfg.model_name + "-" + str(epoch) + ".pt"
save_full_path = str(save_dir) + "/" + save_name
# save model
torch.save(cpu_state, save_full_path)
print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
def load_model_checkpoint(model, rank, cfg):
"""load local checkpoint to rank0 cpu
must be called * before * passing to FSDP"""
if rank != 0:
return
# where is the checkpoint at...
full_state_dict_model_path = (
Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
)
# is it present...
if not full_state_dict_model_path.is_file():
print(
f"model checkpoint {full_state_dict_model_path} not present. Returning..."
)
return
model_checkpoint = torch.load(full_state_dict_model_path)
# integrate into loaded model
model.load_state_dict(model_checkpoint)
print(f"model checkpoint loaded to rank0 cpu")
def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
"""save optimizer state via full state dict"""
print(f"--> optim state call on rank {rank}\n")
# pull all sharded optimizer states to rank0 cpu...
optim_state = FSDP.full_optim_state_dict(model, optimizer)
print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
if rank == 0:
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
save_dir = Path.cwd() / folder_name
save_dir.mkdir(parents=True, exist_ok=True)
opt_save_name = (
"optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
)
opt_save_full_path = save_dir / opt_save_name
print(f"--> saving optimizer state...")
torch.save(optim_state, opt_save_full_path)
print(f"--> saved {opt_save_full_path} to disk")
def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
"""load an fsdp optimizer full_state checkpoint using scatter method
this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
"""
if not optimizer_checkpoint_path.is_file():
print(
f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
)
return
full_osd = None
if rank == 0:
full_osd = torch.load(optimizer_checkpoint_path)
# called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
    print(f"optimizer shard loaded on rank {rank}")
    # return the sharded state so the caller can apply it via optimizer.load_state_dict()
    return sharded_osd
def load_sharded_model_single_gpu(model, model_path):
    """load a sharded_state_dict checkpoint into a model on a single process (no_dist)"""
    reader = FileSystemReader(model_path)
    state_dict = {
        "model": model.state_dict()
    }
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )
model.load_state_dict(state_dict["model"])
print(f"Sharded state checkpoint loaded from {model_path}")
return model
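# Hedged usage sketch (not part of the original file): folding a distributed
# sharded_state_dict checkpoint back into a plain model on a single process,
# e.g. before exporting consolidated weights. The model id and paths are placeholders.
def _example_consolidate_sharded_checkpoint(sharded_ckpt_dir, output_dir):
    from transformers import LlamaForCausalLM  # assumed to be installed alongside this repo
    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
    model = load_sharded_model_single_gpu(model, sharded_ckpt_dir)
    # persist the consolidated weights in whatever format the surrounding pipeline expects,
    # here as a Hugging Face checkpoint directory
    model.save_pretrained(output_dir)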
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from llama_recipes.policies.mixed_precision import *
from llama_recipes.policies.wrapping import *
from llama_recipes.policies.activation_checkpointing_functions import apply_fsdp_checkpointing
from llama_recipes.policies.anyprecision_optimizer import AnyPrecisionAdamW
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from functools import partial
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
non_reentrant_wrapper = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
def apply_fsdp_checkpointing(model):
"""apply activation checkpointing to model
returns None as model is updated directly
"""
print(f"--> applying fsdp activation checkpointing...")
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)
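# Hedged usage sketch (not part of the original file): activation checkpointing is applied
# in place, after the model has been wrapped in FSDP. The wrapping details below are
# placeholders; only the call order is the point of the example.
def _example_enable_activation_checkpointing(model, fsdp_kwargs):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP  # assumed import
    model = FSDP(model, **fsdp_kwargs)   # shard the model first
    apply_fsdp_checkpointing(model)      # then wrap each LlamaDecoderLayer for recomputation
    return model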
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# AnyPrecisionAdamW: a flexible precision AdamW optimizer
# with optional Kahan summation for high precision weight updates.
# Allows direct control over momentum, variance and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset precision reduction for
# the weight updates. This allows full training in BFloat16 (equal or
# better than FP32 results in many cases) due to high precision weight updates.
import torch
from torch.optim.optimizer import Optimizer
class AnyPrecisionAdamW(Optimizer):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
use_kahan_summation=False,
momentum_dtype=torch.bfloat16,
variance_dtype=torch.bfloat16,
compensation_buffer_dtype=torch.bfloat16,
):
"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
            weight_decay (float, optional): weight decay coefficient (default: 0.0)
# Any Precision specific
use_kahan_summation = creates auxiliary buffer to ensure high precision
model param updates (default: False)
            momentum_dtype = dtype for momentum (default: BFloat16)
variance_dtype = dtype for uncentered variance (default: BFloat16)
compensation_buffer_dtype = dtype for Kahan summation
buffer (default: BFloat16)
# Usage
This optimizer implements optimizer states, and Kahan summation
for high precision updates, all in user controlled dtypes.
            Defaults are momentum and variance in BF16.
This can be run in FSDP mixed precision, amp, or full precision,
depending on what training pipeline you wish to work with.
Setting to use_kahan_summation = False, and changing momentum and
variance dtypes to FP32, reverts this to a standard AdamW optimizer.
"""
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
use_kahan_summation=use_kahan_summation,
momentum_dtype=momentum_dtype,
variance_dtype=variance_dtype,
compensation_buffer_dtype=compensation_buffer_dtype,
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
if closure is not None:
with torch.enable_grad():
# to fix linter, we do not keep the returned loss for use atm.
closure()
for group in self.param_groups:
beta1, beta2 = group["betas"]
lr = group["lr"]
weight_decay = group["weight_decay"]
eps = group["eps"]
use_kahan_summation = group["use_kahan_summation"]
momentum_dtype = group["momentum_dtype"]
variance_dtype = group["variance_dtype"]
compensation_buffer_dtype = group["compensation_buffer_dtype"]
for p in group["params"]:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError(
"AnyPrecisionAdamW does not support sparse gradients"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0)
# momentum - EMA of gradient values
state["exp_avg"] = torch.zeros_like(
p,
dtype=momentum_dtype,
)
# variance uncentered - EMA of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p,
dtype=variance_dtype,
)
# optional Kahan summation - accumulated error tracker
if use_kahan_summation:
state["compensation"] = torch.zeros_like(
p,
dtype=compensation_buffer_dtype,
)
# main processing -------------------------
# update the steps for each param group update
state["step"] += 1
step = state["step"]
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
grad = p.grad
# weight decay, AdamW style
if weight_decay:
p.data.mul_(1 - lr * weight_decay)
# update momentum
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# update uncentered variance
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# adjust using bias1
bias_correction1 = 1 - beta1**step
step_size = lr / bias_correction1
# adjust using bias2
denom_correction = (1 - beta2**step) ** 0.5 # avoids math import
centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(
eps, alpha=1
)
# lr update to compensation
if use_kahan_summation:
compensation = state["compensation"]
compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
# update weights with compensation (Kahan summation)
# save error back to compensation for next iteration
temp_buffer = p.detach().clone()
p.data.add_(compensation)
compensation.add_(temp_buffer.sub_(p.data))
else:
# usual AdamW updates
p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
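# Hedged usage sketch (not part of the original file): two ways this optimizer is typically
# configured, following the docstring above. The model is a placeholder; the learning rate
# and weight decay values are illustrative only.
def _example_build_anyprecision_optimizer(model):
    # pure-BF16 optimizer states with Kahan summation to compensate for the reduced precision
    bf16_optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=0.0,
        use_kahan_summation=True,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    )
    # FP32 states and no Kahan summation: behaves like a standard AdamW optimizer
    fp32_optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=1e-4,
        use_kahan_summation=False,
        momentum_dtype=torch.float32,
        variance_dtype=torch.float32,
    )
    return bf16_optimizer, fp32_optimizer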
\ No newline at end of file
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import torch
from torch.distributed.fsdp import (
MixedPrecision,
)
# requires grad scaler in main loop
fpSixteen = MixedPrecision(
param_dtype=torch.float16,
# Gradient communication precision.
reduce_dtype=torch.float16,
# Buffer precision.
buffer_dtype=torch.float16,
)
bfSixteen = MixedPrecision(
param_dtype=torch.bfloat16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16,
# Buffer precision.
buffer_dtype=torch.bfloat16,
cast_forward_inputs=True,
)
bfSixteen_mixed = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
fp32_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
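# Hedged usage sketch (not part of the original file): the policies above are passed to FSDP
# through its mixed_precision argument. As the comment above notes, the fp16 policy also
# needs a gradient scaler in the training loop; bf16 does not. The wrapping below is a
# minimal placeholder, not the repository's actual setup code.
def _example_wrap_with_mixed_precision(model, use_fp16=False):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler  # only needed for fp16
    policy = fpSixteen if use_fp16 else bfSixteen
    model = FSDP(model, mixed_precision=policy)
    scaler = ShardedGradScaler() if use_fp16 else None
    return model, scaler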