Commit 5eaaba41 authored by Rayyyyy

First add in 0524
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import functools
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
)
def get_size_policy(min_params=1e8):
num_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=min_params
)
return num_wrap_policy
def get_llama_wrapper():
"""we register our main layer class and use the fsdp transformer wrapping policy
ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers
"""
# ==== use new transformer wrapper
llama_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
LlamaDecoderLayer,
},
)
return llama_auto_wrap_policy
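# Illustrative usage sketch (an assumption, not part of the original module): the policy
# returned by get_llama_wrapper() is meant to be handed to FSDP via its auto_wrap_policy
# argument, for example:
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   wrapping_policy = get_llama_wrapper()
#   sharded_model = FSDP(model, auto_wrap_policy=wrapping_policy)
#
# where `model` is assumed to be an already-instantiated LlamaForCausalLM.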
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
import os
from typing import List, Union
import fire
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM # @manual
NUM_SHARDS = {
"7B": 1,
"13B": 2,
"34B": 4,
"30B": 4,
"65B": 8,
"70B": 8,
}
def write_model(model_path, model_size, output_base_path):
dtype = torch.bfloat16
params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
num_shards = NUM_SHARDS[model_size]
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = (
1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
).to(dtype)
if "n_kv_heads" in params:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim
model = LlamaForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
)
loaded = model.state_dict()
# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return (
w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
.transpose(1, 2)
.reshape(dim1, dim2)
)
state_dict = [{} for _ in range(num_shards)]
def insert(name: str, tensor: Union[List, torch.Tensor]):
for i in range(num_shards):
state_dict[i][name] = (
tensor[i].clone() if isinstance(tensor, list) else tensor
)
def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
tensors = tensor.chunk(num_shards, dim=dim)
for i, tensor in enumerate(tensors):
state_dict[i][name] = tensor.clone()
insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
insert("norm.weight", loaded["model.norm.weight"])
insert_chunk("output.weight", loaded["lm_head.weight"], 0)
for layer_i in tqdm(range(n_layers), desc="Converting layers"):
ts = (
permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
.view(n_heads_per_shard * num_shards, dims_per_head, dim)
.chunk(num_shards, dim=0)
)
insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
ts = (
permute(
loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
num_key_value_heads,
key_value_dim,
dim,
)
.view(num_local_key_value_heads * num_shards, dims_per_head, dim)
.chunk(num_shards, dim=0)
)
insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
ts = (
loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
.view(num_local_key_value_heads * num_shards, dims_per_head, dim)
.chunk(num_shards, dim=0)
)
insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
insert_chunk(
f"layers.{layer_i}.attention.wo.weight",
loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
1,
)
insert_chunk(
f"layers.{layer_i}.feed_forward.w1.weight",
loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
0,
)
insert_chunk(
f"layers.{layer_i}.feed_forward.w2.weight",
loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
1,
)
insert_chunk(
f"layers.{layer_i}.feed_forward.w3.weight",
loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
0,
)
insert(
f"layers.{layer_i}.attention_norm.weight",
loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
)
insert(
f"layers.{layer_i}.ffn_norm.weight",
loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
)
insert("rope.freqs", inv_freq)
for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
torch.save(
state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
)
def main(
model_path: str,
model_size: str,
output_dir: str,
):
"""Convert llama weights from huggingface format to consolidated format.
params:
model_path: model name or path to the model directory.
model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
output_dir: directory to save Llama weights, should contains params.json.
"""
assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
params_path = os.path.join(output_dir, "params.json")
assert os.path.isfile(params_path), f"{params_path} does not exist"
write_model(model_path, model_size, output_dir)
if __name__ == "__main__":
fire.Fire(main)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from llama_recipes.utils.memory_utils import MemoryTrace
from llama_recipes.utils.dataset_utils import *
from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh
from llama_recipes.utils.train_utils import *
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import inspect
from dataclasses import asdict
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from peft import (
LoraConfig,
AdaptionPromptConfig,
PrefixTuningConfig,
)
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq
from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from llama_recipes.utils.dataset_utils import DATASET_PREPROC
def update_config(config, **kwargs):
if isinstance(config, (tuple, list)):
for c in config:
update_config(c, **kwargs)
else:
for k, v in kwargs.items():
if hasattr(config, k):
setattr(config, k, v)
elif "." in k:
# allow --some_config.some_param=True
config_name, param_name = k.split(".")
if type(config).__name__ == config_name:
if hasattr(config, param_name):
setattr(config, param_name, v)
else:
# In case of specialized config we can warn user
print(f"Warning: {config_name} does not accept parameter: {k}")
elif isinstance(config, train_config):
print(f"Warning: unknown parameter {k}")
def generate_peft_config(train_config, kwargs):
configs = (lora_config, llama_adapter_config, prefix_config)
peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
names = tuple(c.__name__.removesuffix("_config") for c in configs)  # strip the "_config" suffix (rstrip would strip a character set, not a suffix)
if train_config.peft_method not in names:
raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
if train_config.peft_method == "prefix":
raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)")
if train_config.enable_fsdp and train_config.peft_method == "llama_adapter":
raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)")
config = configs[names.index(train_config.peft_method)]()
update_config(config, **kwargs)
params = asdict(config)
peft_config = peft_configs[names.index(train_config.peft_method)](**params)
return peft_config
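# Illustrative sketch (assumed values, not part of the original module): with
# train_config.peft_method == "lora", generate_peft_config builds a peft.LoraConfig from the
# fields of llama_recipes.configs.lora_config, optionally overridden by kwargs:
#
#   cfg = train_config()
#   cfg.peft_method = "lora"
#   peft_cfg = generate_peft_config(cfg, {"r": 16, "lora_alpha": 32})
#
# The override keys ("r", "lora_alpha") are examples and only take effect if they exist on lora_config.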
def generate_dataset_config(train_config, kwargs):
names = tuple(DATASET_PREPROC.keys())
assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
update_config(dataset_config, **kwargs)
return dataset_config
def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
if train_config.batching_strategy == "padding":
if train_config.enable_fsdp:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=mode=="train",
)
else:
kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
elif train_config.batching_strategy == "packing":
if train_config.enable_fsdp:
kwargs["sampler"] = DistributedSampler(
dataset,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=mode=="train",
drop_last=True,
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = default_data_collator
else:
raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
return kwargs
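# Illustrative sketch (an assumption about the surrounding training script, not part of the
# original module): the returned kwargs are meant to be splatted into a torch DataLoader, e.g.
#
#   from torch.utils.data import DataLoader
#   dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, "train")
#   train_dataloader = DataLoader(train_dataset, num_workers=train_config.num_workers_dataloader, pin_memory=True, **dl_kwargs)
#
# num_workers_dataloader and pin_memory are assumptions about the calling code, not requirements of this helper.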
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import importlib
from functools import partial
from pathlib import Path
import torch
from llama_recipes.datasets import (
get_grammar_dataset,
get_alpaca_dataset,
get_samsum_dataset,
)
def load_module_from_py_file(py_file: str) -> object:
"""
This method loads a module from a py file which is not in the Python path
"""
module_name = Path(py_file).name
loader = importlib.machinery.SourceFileLoader(module_name, py_file)
spec = importlib.util.spec_from_loader(module_name, loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)
return module
def get_custom_dataset(dataset_config, tokenizer, split: str):
if ":" in dataset_config.file:
module_path, func_name = dataset_config.file.split(":")
else:
module_path, func_name = dataset_config.file, "get_custom_dataset"
if not module_path.endswith(".py"):
raise ValueError(f"Dataset file {module_path} is not a .py file.")
module_path = Path(module_path)
if not module_path.is_file():
raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
module = load_module_from_py_file(module_path.as_posix())
try:
return getattr(module, func_name)(dataset_config, tokenizer, split)
except AttributeError as e:
print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
raise e
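# Illustrative sketch (hypothetical paths, not part of the original module): dataset_config.file
# accepts either a bare .py path, which resolves to a get_custom_dataset function in that file,
# or a "path.py:function_name" form:
#
#   dataset_config.file = "my_dataset.py"                      # calls my_dataset.get_custom_dataset(...)
#   dataset_config.file = "my_dataset.py:build_chat_dataset"   # calls my_dataset.build_chat_dataset(...)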
DATASET_PREPROC = {
"alpaca_dataset": partial(get_alpaca_dataset),
"grammar_dataset": get_grammar_dataset,
"samsum_dataset": get_samsum_dataset,
"custom_dataset": get_custom_dataset,
}
def get_preprocessed_dataset(
tokenizer, dataset_config, split: str = "train"
) -> torch.utils.data.Dataset:
if dataset_config.dataset not in DATASET_PREPROC:
raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
def get_split():
return (
dataset_config.train_split
if split == "train"
else dataset_config.test_split
)
return DATASET_PREPROC[dataset_config.dataset](
dataset_config,
tokenizer,
get_split(),
)
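# Illustrative usage sketch (assumed config objects, not part of the original module):
#
#   from llama_recipes.configs.datasets import samsum_dataset
#   dataset_config = samsum_dataset()
#   train_ds = get_preprocessed_dataset(tokenizer, dataset_config, split="train")
#
# samsum_dataset is assumed to be one of the dataclasses in llama_recipes.configs.datasets; any
# key of DATASET_PREPROC with its matching config would work the same way.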
from typing import Any, Dict, List, Optional, Union
import time
import torch
from torch.utils.flop_counter import FlopCounterMode
class FlopMeasure(FlopCounterMode):
"""
``FlopMeasure`` is a customized context manager that counts the number of
flops within its context. It is based on ``FlopCounterMode`` with an additional step() method so that the flop counting
will only start after the warmup stage.
It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
Example usage
.. code-block:: python
model = ...
flop_counter = FlopMeasure(model,local_rank=0,warmup_step=3)
for batch in enumerate(dataloader):
with flop_counter:
model(batch)
flop_counter.step()
"""
def __init__(
self,
mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
depth: int = 2,
display: bool = True,
custom_mapping: Dict[Any, Any] = None,
rank=None,
warmup_step: int = 3,
):
super().__init__(mods, depth, display, custom_mapping)
self.rank = rank
self.warmup_step = warmup_step
self.start_time = 0
self.end_time = 0
def step(self):
# Decrease warmup_step by 1 every step; flop counting starts once warmup_step == 0, and warmup_step stops decreasing after it reaches -1.
if self.warmup_step >= 0:
self.warmup_step -= 1
if self.warmup_step == 0 and self.start_time == 0:
self.start_time = time.time()
elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
self.end_time = time.time()
def __enter__(self):
if self.warmup_step == 0:
self.start_time = time.time()
super().__enter__()
return self
def is_done(self):
return self.warmup_step == -1
def get_total_flops(self):
return super().get_total_flops()
def get_flops_per_sec(self):
if self.start_time == 0 or self.end_time == 0:
print("Warning: flop count did not finish correctly")
return 0
return super().get_total_flops()/ (self.end_time - self.start_time)
def get_table(self, depth=2):
return super().get_table(depth)
def __exit__(self, *args):
if self.get_total_flops() == 0:
print(
"Warning: did not record any flops this time. Skipping the flop report"
)
else:
if self.display:
if self.rank is None or self.rank == 0:
print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time))
print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12))
print("The tflop_count table is below:")
print(self.get_table(self.depth))
# Disable the display feature so that we don't print the table again
self.display = False
super().__exit__(*args)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
# when warmup_step is 0, count the flops and return the original output
if self.warmup_step == 0:
return super().__torch_dispatch__(func, types, args, kwargs)
# otherwise, just return the original output
return func(*args, **kwargs)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from torch.distributed._tensor.device_mesh import init_device_mesh
import os
def fsdp_auto_wrap_policy(model, transformer_layer_name):
import functools
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
def lambda_policy_fn(module):
if (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
):
return True
return False
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
transformer_layer_name,
),
)
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
return auto_wrap_policy
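# Illustrative sketch (an assumption, not part of the original module): for a PEFT-wrapped Llama
# model, the combined policy is typically built with the decoder layer class and handed to FSDP:
#
#   from transformers.models.llama.modeling_llama import LlamaDecoderLayer
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
#   sharded_model = FSDP(model, auto_wrap_policy=my_auto_wrapping_policy)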
def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
"""
Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training.
This function requires explicit sizes for replica and sharding groups to accommodate models
whose GPU fit is unknown, providing flexibility in distributed training setups.
Args:
replica_group_size (int): The size of each replica group. Must be provided to ensure
the model fits within the available resources.
sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to
ensure the correct distribution of model parameters.
device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
with the local rank as the device index.
Returns:
A device mesh object compatible with FSDP.
Raises:
ValueError: If replica_group_size or sharding_group_size are not provided, or if the
world size is not evenly divisible by the sharding group size.
RuntimeError: If a valid device mesh cannot be created.
Usage:
If your model fits on 4 GPUs, and you have 3 nodes of 8 GPUs, then:
Sharding_Group_Size = 4
Replica_Groups_Size = (24 total GPUs / 4 per sharding group) = 6 Replica Groups
>>> device_mesh = hsdp_device_mesh(replica_group_size, sharding_group_size)
>>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
"""
if replica_group_size is None or sharding_group_size is None:
raise ValueError("Both replica_group_size and sharding_group_size must be provided.")
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
device = device or f"cuda"
if world_size % sharding_group_size != 0:
raise ValueError(f"World size {world_size} is not evenly divisible by "
f"sharding group size {sharding_group_size}.")
if (world_size // sharding_group_size) % replica_group_size != 0:
raise ValueError(f"The calculated number of replica groups is not evenly divisible by "
f"replica_group_size {replica_group_size}.")
device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size))
if device_mesh is None:
raise RuntimeError("Failed to create a valid device mesh.")
return device_mesh
# Convert Hugging Face Llama weights to the official Llama consolidated format
This is the reverse of the `convert_llama_weights_to_hf.py` script from the transformers package.
## Step 0: Convert to consolidated format
- Create an output directory for the converted weights, such as `test70B`.
- Copy the `params.json` file from the official Llama download into that directory.
- Run the conversion script. `model-path` can be a Hugging Face Hub model name or a local HF model directory.
```
python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
```
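After the conversion completes, the output directory should contain `params.json` plus one `consolidated.XX.pth` file per shard (for example, `consolidated.00.pth` through `consolidated.07.pth` for the 8-shard 70B model).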
## Step 1: Run inference
Check out the official Llama inference [repo](https://github.com/facebookresearch/llama) and test the converted weights using chat or text completion.
```
torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
```
For validation, compare the converted weights with the official Llama 2 weights:
```
python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import gc
import glob
import os
import sys
import torch
import tqdm
def main() -> None:
"""Compare two llama checkpoint directories"""
one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
assert len(one_files) == len(
two_files
), "One directory has {} files while another has {} files.".format(
len(one_files), len(two_files)
)
deltas = []
for i in tqdm.trange(len(one_files), desc="Comparing shards"):
one = torch.load(one_files[i])
two = torch.load(two_files[i])
assert len(one) == len(
two
), "shard should have the same length: {} != {}".format(len(one), len(two))
for v, w in zip(one.items(), two.items()):
assert v[0] == w[0], "{} != {}".format(v[0], w[0])
assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
v[0], v[1].shape, w[1].shape
)
delta = (v[1] - w[1]).abs().max().item()
deltas.append((i, v[0], delta))
del one
del two
gc.collect()
deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
print("Top 10 largest deltas:")
for i, k, v in deltas[:10]:
print(f" shard {i} {k}: {v}")
if __name__ == "__main__":
main()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import gc
import psutil
import threading
import torch
from accelerate.utils import is_xpu_available
def byte2gb(x):
return int(x / 2**30)
# This context manager is used to track the peak memory usage of the process
class MemoryTrace:
def __enter__(self):
gc.collect()
if is_xpu_available():
torch.xpu.empty_cache()
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = byte2gb(torch.xpu.memory_allocated())
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = byte2gb(torch.cuda.memory_allocated())
self.process = psutil.Process()
self.cpu_begin = byte2gb(self.cpu_mem_used())
self.peak_monitoring = True
peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
peak_monitor_thread.daemon = True
peak_monitor_thread.start()
return self
def cpu_mem_used(self):
"""get resident set size memory for the current process"""
return self.process.memory_info().rss
def peak_monitor_func(self):
self.cpu_peak = -1
while True:
self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
# can't sleep or will not catch the peak right (this comment is here on purpose)
# time.sleep(0.001) # 1msec
if not self.peak_monitoring:
break
def __exit__(self, *exc):
self.peak_monitoring = False
gc.collect()
if is_xpu_available():
torch.xpu.empty_cache()
self.end = byte2gb(torch.xpu.memory_allocated())
self.peak = byte2gb(torch.xpu.max_memory_allocated())
xpu_info = torch.xpu.memory_stats()
self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
self.malloc_retries = xpu_info.get("num_alloc_retries", 0)
self.m_ooms = xpu_info.get("num_ooms", 0)
self.used = byte2gb(self.end - self.begin)
self.peaked = byte2gb(self.peak - self.begin)
self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
elif torch.cuda.is_available():
torch.cuda.empty_cache()
self.end = byte2gb(torch.cuda.memory_allocated())
self.peak = byte2gb(torch.cuda.max_memory_allocated())
cuda_info = torch.cuda.memory_stats()
self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
self.malloc_retries = cuda_info.get("num_alloc_retries", 0)
self.m_ooms = cuda_info.get("num_ooms", 0)
self.used = byte2gb(self.end - self.begin)
self.peaked = byte2gb(self.peak - self.begin)
self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
self.cpu_end = self.cpu_mem_used()
self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
# print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
def print_stats(self):
device_str = None
if is_xpu_available():
device_str = "XPU"
elif torch.cuda.is_available():
device_str = "CUDA"
if device_str:
print(f"Max {device_str} memory allocated was {self.peak} GB")
print(f"Max {device_str} memory reserved was {self.max_reserved} GB")
print(f"Peak active {device_str} memory was {self.peak_active_gb} GB")
print(f"{device_str} Malloc retries : {self.malloc_retries}")
print(f"CPU Total Peak Memory consumed during the train (max): {self.cpu_peaked + self.cpu_begin} GB")
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
import matplotlib.pyplot as plt
import argparse
import os
def plot_metric(data, metric_name, x_label, y_label, title, colors):
plt.figure(figsize=(7, 6))
plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0])
plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1])
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(f'Train and Validation Epoch {title}')
plt.legend()
plt.tight_layout()
def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color):
plt.plot(data[f'{metric_name}'], label=f'{title}', color=color)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.title(title)
plt.legend()
plt.tight_layout()
def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0])
plt.subplot(1, 2, 2)
plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1])
plt.tight_layout()
def plot_metrics(file_path):
if not os.path.exists(file_path):
print(f"File {file_path} does not exist.")
return
with open(file_path, 'r') as f:
try:
data = json.load(f)
except json.JSONDecodeError:
print("Invalid JSON file.")
return
directory = os.path.dirname(file_path)
filename_prefix = os.path.basename(file_path).split('.')[0]
plot_metric(data, 'loss', 'Epoch', 'Loss', 'Loss', ['b', 'r'])
plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png"))
plt.close()
plot_metric(data, 'perplexity', 'Epoch', 'Perplexity', 'Perplexity', ['g', 'm'])
plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png"))
plt.close()
plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r'])
plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png"))
plt.close()
plot_metrics_by_step(data, 'perplexity', 'Step', 'Perplexity', ['g', 'm'])
plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png"))
plt.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Plot metrics from JSON file.')
parser.add_argument('--file_path', required=True, type=str, help='Path to the metrics JSON file.')
args = parser.parse_args()
plot_metrics(args.file_path)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import time
import yaml
from contextlib import nullcontext
from pathlib import Path
from pkg_resources import packaging
from datetime import datetime
import contextlib
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tqdm import tqdm
from transformers import LlamaTokenizer
import json
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen, bfSixteen, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from accelerate.utils import is_xpu_available, is_ccl_available
from llama_recipes.utils.flop_utils import FlopMeasure
def set_tokenizer_params(tokenizer: LlamaTokenizer):
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"
@contextlib.contextmanager
def profile(cfg, local_rank=None):
use_profiler: bool = cfg.use_profiler
use_flop_counter: bool = cfg.flop_counter
if use_flop_counter and use_profiler:
raise ValueError("Cannot use both profiler and flop counter")
if use_profiler:
# profiler needs a warmup stage to get the accurate profiling results
wait_step, warmup_step, active_step = 1, 2, 3
min_step = wait_step + warmup_step + active_step + 1
if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
cfg.profiler_dir
),
profile_memory=True,
with_stack=False,
with_flops=True,
record_shapes=True,
) as torch_profiler:
yield torch_profiler
elif use_flop_counter:
if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
yield flop_counter
else:
torch_profiler = contextlib.nullcontext()
yield None
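# Illustrative usage sketch (mirroring the call in train() below, not new API): the
# profiler/flop-counter context wraps the step loop and is advanced once per step.
#
#   with profile(train_config, local_rank) as profile_context:
#       for step, batch in enumerate(train_dataloader):
#           ...  # forward/backward/optimizer step
#           if train_config.use_profiler or train_config.flop_counter:
#               profile_context.step()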
def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
"""
Trains the model on the given dataloader
Args:
model: The model to be trained
train_dataloader: The dataloader containing the training data
optimizer: The optimizer used for training
lr_scheduler: The learning rate scheduler
gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
num_epochs: The number of epochs to train for
local_rank: The rank of the current node in a distributed setting
train_config: The training configuration
eval_dataloader: The dataloader containing the eval data
tokenizer: tokenizer used in the eval for decoding the predictions
Returns: results dictionary containing average training and validation perplexity and loss
"""
# Create a gradient scaler for fp16
if train_config.use_fp16 and train_config.enable_fsdp:
scaler = ShardedGradScaler()
elif train_config.use_fp16 and not train_config.enable_fsdp:
scaler = torch.cuda.amp.GradScaler()
if train_config.enable_fsdp:
world_size = int(os.environ["WORLD_SIZE"])
autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
train_prep = []
train_loss = []
val_prep = []
val_loss =[]
if train_config.save_metrics:
if not os.path.exists(train_config.output_dir):
os.makedirs(train_config.output_dir, exist_ok=True)
metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
train_step_perplexity = []
train_step_loss = []
val_step_loss = []
val_step_perplexity = []
epoch_times = []
checkpoint_times = []
results = {}
best_val_loss = float("inf")
total_train_steps = 0
max_steps_reached = False # Flag to indicate max training steps reached
# Start the training loop
for epoch in range(train_config.num_epochs):
# stop when the maximum number of training steps is reached
if max_steps_reached:
break
epoch_start_time = time.perf_counter()
with MemoryTrace() as memtrace: # track the memory usage
model.train()
total_loss = 0.0
total_length = len(train_dataloader)//gradient_accumulation_steps
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
with profile(train_config,local_rank) as profile_context:
for step, batch in enumerate(train_dataloader):
total_train_steps += 1
# stop when the maximum number of training steps is reached
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
max_steps_reached = True
if not train_config.enable_fsdp or local_rank==0:
print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
break
for key in batch.keys():
if train_config.enable_fsdp:
if is_xpu_available():
batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
else:
batch[key] = batch[key].to(local_rank)
else:
if is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
with autocast():
loss = model(**batch).loss
loss = loss / gradient_accumulation_steps
if train_config.save_metrics:
train_step_loss.append(loss.detach().float().item())
train_step_perplexity.append(float(torch.exp(loss.detach().float())))
total_loss += loss.detach().float()
if train_config.use_fp16:
# if fp16 is enabled, use gradient scaler to handle gradient update
scaler.scale(loss).backward()
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
scaler.unscale_(optimizer)
if train_config.enable_fsdp:
model.clip_grad_norm_(train_config.gradient_clipping_threshold)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
pbar.update(1)
else:
# regular backpropagation when fp16 is not used
loss.backward()
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
if train_config.enable_fsdp:
model.clip_grad_norm_(train_config.gradient_clipping_threshold)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
optimizer.step()
optimizer.zero_grad()
pbar.update(1)
if train_config.use_profiler or train_config.flop_counter:
profile_context.step()
if train_config.flop_counter and profile_context.is_done():
TFlops = profile_context.get_flops_per_sec() / 1e12
if wandb_run:
if not train_config.enable_fsdp or rank==0:
wandb_run.log({
'train/epoch': epoch + 1,
'train/step': epoch * len(train_dataloader) + step,
'train/loss': loss.detach().float(),
})
pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
if train_config.save_metrics:
save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
pbar.close()
epoch_end_time = time.perf_counter()-epoch_start_time
epoch_times.append(epoch_end_time)
# Reducing total_loss across all devices if there's more than one CUDA device
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
train_epoch_loss = total_loss / len(train_dataloader)
if train_config.enable_fsdp:
train_epoch_loss = train_epoch_loss/world_size
train_perplexity = torch.exp(train_epoch_loss)
train_prep.append(float(train_perplexity))
train_loss.append(float(train_epoch_loss))
if not train_config.enable_fsdp or rank==0:
memtrace.print_stats()
# Update the learning rate as needed
lr_scheduler.step()
if train_config.run_validation:
eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
if train_config.save_metrics:
val_step_loss.extend(temp_val_loss)
val_step_perplexity.extend(temp_step_perplexity)
checkpoint_start_time = time.perf_counter()
if train_config.save_model and eval_epoch_loss < best_val_loss:
if train_config.enable_fsdp:
dist.barrier()
if train_config.use_peft:
if train_config.enable_fsdp:
if rank==0:
print(f"we are about to save the PEFT modules")
else:
print(f"we are about to save the PEFT modules")
model.save_pretrained(train_config.output_dir)
if train_config.enable_fsdp:
if rank==0:
print(f"PEFT modules are saved in {train_config.output_dir} directory")
else:
print(f"PEFT modules are saved in {train_config.output_dir} directory")
else:
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
save_model_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
print("=====================================================")
save_model_and_optimizer_sharded(model, rank, train_config)
if train_config.save_optimizer:
save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
print("=====================================================")
if not train_config.use_peft and train_config.save_optimizer:
save_optimizer_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
print("=====================================================")
if train_config.enable_fsdp:
dist.barrier()
checkpoint_end_time = time.perf_counter() - checkpoint_start_time
checkpoint_times.append(checkpoint_end_time)
if eval_epoch_loss < best_val_loss:
best_val_loss = eval_epoch_loss
if train_config.enable_fsdp:
if rank==0:
print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
else:
print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
val_loss.append(float(best_val_loss))
val_prep.append(float(eval_ppl))
if train_config.enable_fsdp:
if rank==0:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
else:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
# Saving the results every epoch to plot later
if train_config.save_metrics:
save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
avg_epoch_time = sum(epoch_times)/ len(epoch_times)
avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
avg_train_prep = sum(train_prep)/len(train_prep)
avg_train_loss = sum(train_loss)/len(train_loss)
if train_config.run_validation:
avg_eval_prep = sum(val_prep)/len(val_prep)
avg_eval_loss = sum(val_loss)/len(val_loss)
results['avg_train_prep'] = avg_train_prep
results['avg_train_loss'] = avg_train_loss
if train_config.run_validation:
results['avg_eval_prep'] = avg_eval_prep
results['avg_eval_loss'] = avg_eval_loss
results["avg_epoch_time"] = avg_epoch_time
results["avg_checkpoint_time"] = avg_checkpoint_time
if train_config.save_metrics:
results["metrics_filename"] = metrics_filename
if train_config.flop_counter:
results["model_tflops"]= TFlops
#saving the training params including fsdp setting for reference.
if train_config.enable_fsdp and not train_config.use_peft and rank==0:
save_train_params(train_config, fsdp_config, rank)
return results
def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
"""
Evaluates the model on the given dataloader
Args:
model: The model to evaluate
eval_dataloader: The dataloader containing the evaluation data
local_rank: The rank of the current node in a distributed setting
tokenizer: The tokenizer used to decode predictions
Returns: eval_ppl, eval_epoch_loss
"""
if train_config.enable_fsdp:
world_size = int(os.environ["WORLD_SIZE"])
model.eval()
eval_preds = []
val_step_loss = []
val_step_perplexity = []
eval_loss = 0.0 # Initialize evaluation loss
total_eval_steps = 0
with MemoryTrace() as memtrace:
for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
total_eval_steps += 1
# stop when the maximum number of eval steps is reached
if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
if not train_config.enable_fsdp or local_rank==0:
print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
break
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
else:
if is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
# Ensure no gradients are computed for this scope to save memory
with torch.no_grad():
# Forward pass and compute loss
outputs = model(**batch)
loss = outputs.loss
if train_config.save_metrics:
val_step_loss.append(loss.detach().float().item())
val_step_perplexity.append(float(torch.exp(loss.detach().float())))
eval_loss += loss.detach().float()
# Decode predictions and add to evaluation predictions list
preds = torch.argmax(outputs.logits, -1)
eval_preds.extend(
tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
)
# If there's more than one CUDA device, reduce evaluation loss across all devices
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
# Compute average loss and perplexity
eval_epoch_loss = eval_loss / len(eval_dataloader)
if train_config.enable_fsdp:
eval_epoch_loss = eval_epoch_loss/world_size
eval_ppl = torch.exp(eval_epoch_loss)
# Print evaluation metrics
if train_config.enable_fsdp:
if local_rank==0:
print(f" {eval_ppl=} {eval_epoch_loss=}")
else:
print(f" {eval_ppl=} {eval_epoch_loss=}")
if wandb_run:
wandb_run.log({
'eval/perplexity': eval_ppl,
'eval/loss': eval_epoch_loss,
}, commit=False)
return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
def freeze_transformer_layers(model, num_layer):
for i, layer in enumerate(model.model.layers):
if i < num_layer:
for param in layer.parameters():
param.requires_grad = False
def check_frozen_layers_peft_model(model):
for i, layer in enumerate(model.base_model.model.model.layers):
for name, param in layer.named_parameters():
print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
def setup():
"""Initialize the process group for distributed training"""
if is_ccl_available():
# distributed training on xpus
dist.init_process_group("ccl")
else:
dist.init_process_group("nccl")
def setup_environ_flags(rank):
"""Set environment flags for debugging purposes"""
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
# os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
# This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
# Note this is only available in PyTorch Nightlies (as of July 30, 2023)
# os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
if rank == 0:
print(f"--> Running with torch dist debug set to detail")
def cleanup():
"""Clean up the process group after training"""
dist.destroy_process_group()
def clear_gpu_cache(rank=None):
"""Clear the GPU cache for all ranks"""
if rank == 0:
print(f"Clearing GPU cache for all ranks")
if is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()
def get_parameter_dtypes(model):
"""Get the data types of model parameters"""
parameter_dtypes = {}
for name, parameter in model.named_parameters():
parameter_dtypes[name] = parameter.dtype
return parameter_dtypes
def print_model_size(model, config, rank: int = 0) -> None:
"""
Print the model name and the number of trainable parameters.
Args:
model: The PyTorch model.
config: Training config; config.model_name is used as the model name.
rank (int, optional): Current process's rank. Defaults to 0.
"""
if rank == 0:
print(f"--> Model {config.model_name}")
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
def get_policies(cfg, rank):
"""Get the policies for mixed precision and fsdp wrapping"""
verify_bfloat_support = ((
torch.version.cuda
and torch.cuda.is_bf16_supported()
and packaging.version.parse(torch.version.cuda).release >= (11, 0)
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
) or
(is_xpu_available()))
mixed_precision_policy = None
wrapping_policy = None
# Mixed precision
if cfg.mixed_precision:
bf16_ready = verify_bfloat_support
if bf16_ready and not cfg.use_fp16:
mixed_precision_policy = bfSixteen
if rank == 0:
print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
elif cfg.use_fp16:
mixed_precision_policy = fpSixteen
if rank == 0:
print(f"FP16 enabled")
else:
print(f"bFloat16 support not present. Using FP32, and not mixed precision")
wrapping_policy = get_llama_wrapper()
return mixed_precision_policy, wrapping_policy
def save_train_params(train_config, fsdp_config, rank):
"""
This function saves the train_config and FSDP config into a train_params.yaml.
This will be used by the converter script in the inference folder to fetch the HF model name or path.
It is also helpful as a log for future reference.
"""
# Convert the train_config and fsdp_config objects to dictionaries,
# converting all values to strings to ensure they can be serialized into a YAML file
train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
# Merge the two dictionaries into one
train_params_dict = {**train_config_dict, **fsdp_config_dict}
# Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
folder_name = (
train_config.dist_checkpoint_root_folder
+ "/"
+ train_config.dist_checkpoint_folder
+ "-"
+ train_config.model_name
)
save_dir = Path.cwd() / folder_name
# If the directory does not exist, create it
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Convert the dictionary to a YAML string
config_yaml = yaml.dump(train_params_dict, indent=4)
file_name = os.path.join(save_dir,'train_params.yaml')
# Check if there's a directory with the same name as the file
if os.path.isdir(file_name):
print(f"Error: {file_name} is a directory, not a file.")
else:
# Write the YAML string to the file
with open(file_name, 'w') as f:
f.write(config_yaml)
if rank==0:
print(f"training params are saved in {file_name}")
def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
metrics_data = {
"train_step_loss": train_step_loss,
"train_epoch_loss": train_epoch_loss,
"train_step_perplexity": train_step_ppl,
"train_epoch_perplexity": train_epoch_ppl,
"val_step_loss": val_step_loss,
"val_epoch_loss": val_epoch_loss,
"val_step_perplexity": val_step_ppl,
"val_epoch_perplexity": val_epoch_ppl
}
with open(output_filename, "w") as f:
json.dump(metrics_data, f)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest
from transformers import AutoTokenizer
ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3-8B"]
@pytest.fixture(params=LLAMA_VERSIONS)
def llama_version(request):
return request.param
@pytest.fixture(scope="module")
def llama_tokenizer(request):
return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
@pytest.fixture
def setup_tokenizer(llama_tokenizer, llama_version):
def _helper(tokenizer_mock):
#Align with Llama 2 tokenizer
tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version]
return _helper
def pytest_addoption(parser):
parser.addoption(
"--unskip-missing-tokenizer",
action="store_true",
default=False, help="disable skip missing tokenizer")
def pytest_configure(config):
config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
def pytest_collection_modifyitems(config, items):
if config.getoption("--unskip-missing-tokenizer"):
return
try:
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer_available = True
except OSError:
tokenizer_available = False
skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
for item in items:
if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
item.add_marker(skip_missing_tokenizer)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest
from unittest.mock import patch
from transformers import LlamaTokenizer
EXPECTED_RESULTS={
"meta-llama/Llama-2-7b-hf":{
"example_1": "[INST] Who made Berlin [/INST] dunno",
"example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
},
"meta-llama/Meta-Llama-3-8B":{
"example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
"example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
},
}
def check_padded_entry(batch, tokenizer):
seq_len = sum(batch["attention_mask"][0])
assert seq_len < len(batch["attention_mask"][0])
if tokenizer.vocab_size >= 128000:
END_OF_TEXT_ID = 128009
else:
END_OF_TEXT_ID = tokenizer.eos_token_id
assert batch["labels"][0][0] == -100
assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID
assert batch["labels"][0][-1] == -100
assert batch["input_ids"][0][0] == tokenizer.bos_token_id
assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
@pytest.mark.skip(reason="Flakey due to random dataset order @todo fix order")
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
from llama_recipes.finetuning import main
setup_tokenizer(tokenizer)
skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
kwargs = {
"dataset": "custom_dataset",
"model_name": llama_version,
"custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py",
"custom_dataset.train_split": "validation",
"batch_size_training": 2,
"val_batch_size": 4,
"use_peft": False,
"batching_strategy": "padding"
}
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
tokenizer = args[3]
assert len(train_dataloader) == 1120
assert len(eval_dataloader) == 1120 // 2
it = iter(eval_dataloader)
batch = next(it)
STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
assert batch["input_ids"].size(0) == 4
assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
check_padded_entry(batch, tokenizer)
it = iter(train_dataloader)
next(it)
batch = next(it)
STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
assert batch["input_ids"].size(0) == 2
assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
check_padded_entry(batch, tokenizer)
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker, llama_version):
from llama_recipes.finetuning import main
tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
kwargs = {
"dataset": "custom_dataset",
"custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py:get_unknown_dataset",
"batch_size_training": 1,
"use_peft": False,
}
with pytest.raises(AttributeError):
main(**kwargs)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest
from unittest.mock import patch
EXPECTED_RESULTS = {
"meta-llama/Llama-2-7b-hf":{
"label": 1152,
"pos": 31,
},
"meta-llama/Meta-Llama-3-8B":{
"label": 40,
"pos": 26,
},
}
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
from llama_recipes.finetuning import main
setup_tokenizer(tokenizer)
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
BATCH_SIZE = 8
kwargs = {
"model_name": llama_version,
"batch_size_training": BATCH_SIZE,
"val_batch_size": 1,
"use_peft": False,
"dataset": "grammar_dataset",
"batching_strategy": "padding",
}
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
VAL_SAMPLES = 2988
TRAIN_SAMPLES = 13016
assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
assert len(eval_dataloader) == VAL_SAMPLES
batch = next(iter(train_dataloader))
assert "labels" in batch.keys()
assert "input_ids" in batch.keys()
assert "attention_mask" in batch.keys()
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
token = args[3]
assert batch["input_ids"][0][0] == token.bos_token_id
assert batch["labels"][0][-1] == token.eos_token_id
assert batch["input_ids"][0][-1] == token.eos_token_id
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest
from functools import partial
from unittest.mock import patch
EXPECTED_RESULTS = {
"meta-llama/Llama-2-7b-hf":{
"label": 8432,
"pos": 242,
},
"meta-llama/Meta-Llama-3-8B":{
"label": 2250,
"pos": 211,
},
}
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
from llama_recipes.finetuning import main
setup_tokenizer(tokenizer)
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
BATCH_SIZE = 8
kwargs = {
"model_name": llama_version,
"batch_size_training": BATCH_SIZE,
"val_batch_size": 1,
"use_peft": False,
"dataset": "samsum_dataset",
"batching_strategy": "padding",
}
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
token = args[3]
VAL_SAMPLES = 818
TRAIN_SAMPLES = 14732
assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
assert len(eval_dataloader) == VAL_SAMPLES
batch = next(iter(train_dataloader))
assert "labels" in batch.keys()
assert "input_ids" in batch.keys()
assert "attention_mask" in batch.keys()
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
assert batch["input_ids"][0][0] == token.bos_token_id
assert batch["labels"][0][-1] == token.eos_token_id
assert batch["input_ids"][0][-1] == token.eos_token_id
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest
from unittest.mock import patch
EXPECTED_SAMPLE_NUMBER ={
"meta-llama/Llama-2-7b-hf": {
"train": 96,
"eval": 42,
},
"meta-llama/Meta-Llama-3-8B": {
"train": 79,
"eval": 34,
}
}
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
from llama_recipes.finetuning import main
setup_tokenizer(tokenizer)
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
kwargs = {
"model_name": llama_version,
"batch_size_training": 8,
"val_batch_size": 1,
"use_peft": False,
"dataset": "samsum_dataset",
"batching_strategy": "packing",
}
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
batch = next(iter(train_dataloader))
assert "labels" in batch.keys()
assert "input_ids" in batch.keys()
assert "attention_mask" in batch.keys()
assert batch["labels"][0].size(0) == 4096
assert batch["input_ids"][0].size(0) == 4096
assert batch["attention_mask"][0].size(0) == 4096
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.finetuning.setup')
@patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist')
def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
import os
from llama_recipes.finetuning import main
setup_tokenizer(tokenizer)
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
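    # Fake a 2-process distributed launch (rank 1 of world size 2) via environment
    # variables so the packed dataset gets sharded without actually spawning processes.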
rank = 1
os.environ['LOCAL_RANK'] = f'{rank}'
os.environ['RANK'] = f'{rank}'
os.environ['WORLD_SIZE'] = '2'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
kwargs = {
"model_name": llama_version,
"batch_size_training": 8,
"val_batch_size": 1,
"use_peft": False,
"dataset": "samsum_dataset",
"batching_strategy": "packing",
"enable_fsdp": True
}
is_initialized.return_value = True
dist.get_rank.return_value = rank
dist.get_world_size.return_value = 2
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] // 2
    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] // 2
import sys
from pathlib import Path
from typing import List, Literal, TypedDict
from unittest.mock import patch
import pytest
import torch
from llama_recipes.inference.chat_utils import read_dialogs_from_file
ROOT_DIR = Path(__file__).parents[2]
CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
Role = Literal["user", "assistant"]
class Message(TypedDict):
role: Role
content: str
Dialog = List[Message]
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
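# The helpers below rebuild the expected chat prompts token-by-token (Llama 3 header
# tags here, Llama 2 [INST] wrapping further down) so the test can compare them against
# the input_ids the chat_completion script passes to generate().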
def _encode_header(message, tokenizer):
tokens = []
tokens.extend(tokenizer.encode("<|start_header_id|>"))
tokens.extend(tokenizer.encode(message["role"]))
tokens.extend(tokenizer.encode("<|end_header_id|>"))
tokens.extend(tokenizer.encode("\n\n"))
return tokens
def _encode_message(message, tokenizer):
tokens = _encode_header(message, tokenizer)
tokens.extend(tokenizer.encode(message["content"].strip()))
tokens.extend(tokenizer.encode("<|eot_id|>"))
return tokens
def _format_dialog(dialog, tokenizer):
tokens = []
tokens.extend(tokenizer.encode("<|begin_of_text|>"))
for msg in dialog:
tokens.extend(_encode_message(msg, tokenizer))
tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
return tokens
def _format_tokens_llama3(dialogs, tokenizer):
return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
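# Llama 2 chat format: the system prompt, if any, is folded into the first user turn
# between <<SYS>> markers, and each user/assistant pair is wrapped in [INST] ... [/INST].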
def _format_tokens_llama2(dialogs, tokenizer):
prompt_tokens = []
for dialog in dialogs:
if dialog[0]["role"] == "system":
dialog = [
{
"role": dialog[1]["role"],
"content": B_SYS
+ dialog[0]["content"]
+ E_SYS
+ dialog[1]["content"],
}
] + dialog[2:]
assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
[msg["role"] == "assistant" for msg in dialog[1::2]]
), (
"model only supports 'system','user' and 'assistant' roles, "
"starting with user and alternating (u/a/u/a/u...)"
)
"""
Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
Here, we are adding it manually.
"""
dialog_tokens: List[int] = sum(
[
tokenizer.encode(
f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
)
+ [tokenizer.eos_token_id]
for prompt, answer in zip(dialog[::2], dialog[1::2])
],
[],
)
assert (
dialog[-1]["role"] == "user"
), f"Last message must be from user, got {dialog[-1]['role']}"
dialog_tokens += tokenizer.encode(
f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
)
prompt_tokens.append(dialog_tokens)
return prompt_tokens
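# Example (hypothetical single-turn dialog) of how a reference prompt is built; the test
# below does the same for every dialog in chats.json:
#
#   dialogs = [[{"role": "user", "content": "Hello"}]]
#   ref = _format_tokens_llama2(dialogs, tokenizer)
#   # ref[0] == tokenizer.encode("[INST] Hello [/INST]")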
@pytest.mark.skip_missing_tokenizer
@patch("chat_completion.AutoTokenizer")
@patch("chat_completion.load_model")
def test_chat_completion(
load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
):
from chat_completion import main
setup_tokenizer(tokenizer)
load_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
kwargs = {
"prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
}
main(llama_version, **kwargs)
dialogs = read_dialogs_from_file(kwargs["prompt_file"])
format_tokens = (
_format_tokens_llama2
if llama_version == "meta-llama/Llama-2-7b-hf"
else _format_tokens_llama3
)
REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
    # The script calls generate() once per dialog; with the mocked model those calls sit
    # at every 4th entry of mock_calls, so compare each one against its reference prompt.
    for i in range(5):
        assert all(
            (
                load_model.return_value.generate.mock_calls[i * 4][2]["input_ids"].cpu()
                == torch.tensor(REF_RESULT[i]).long()
            ).tolist()
        )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
from unittest.mock import patch
import pytest
import torch
from llama_recipes.data.sampler import LengthBasedBatchSampler
from llama_recipes.finetuning import main
from pytest import approx
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler
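# Minimal single-sample dataset so finetuning.main() can build its dataloaders without
# downloading or tokenizing any real data.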
def get_fake_dataset():
return [
{
"input_ids": [1],
"attention_mask": [1],
"labels": [1],
}
]
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_no_validation(
step_lr,
optimizer,
get_dataset,
tokenizer,
get_model,
train,
cuda,
cuda_is_available,
):
kwargs = {"run_validation": False}
get_dataset.return_value = get_fake_dataset()
cuda.return_value = cuda_is_available
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
assert isinstance(train_dataloader, DataLoader)
assert eval_dataloader is None
if cuda_is_available:
assert get_model.return_value.to.call_count == 1
assert get_model.return_value.to.call_args.args[0] == "cuda"
else:
assert get_model.return_value.to.call_count == 0
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_with_validation(
step_lr,
optimizer,
get_dataset,
tokenizer,
get_model,
train,
cuda,
cuda_is_available,
):
kwargs = {"run_validation": True}
get_dataset.return_value = get_fake_dataset()
cuda.return_value = cuda_is_available
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader = args[1]
eval_dataloader = args[2]
assert isinstance(train_dataloader, DataLoader)
assert isinstance(eval_dataloader, DataLoader)
if cuda_is_available:
assert get_model.return_value.to.call_count == 1
assert get_model.return_value.to.call_args.args[0] == "cuda"
else:
assert get_model.return_value.to.call_count == 0
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.generate_peft_config")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_peft_lora(
step_lr,
optimizer,
get_peft_model,
gen_peft_config,
get_dataset,
tokenizer,
get_model,
train,
cuda,
cuda_is_available,
):
kwargs = {"use_peft": True}
get_dataset.return_value = get_fake_dataset()
cuda.return_value = cuda_is_available
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
main(**kwargs)
if cuda_is_available:
assert get_peft_model.return_value.to.call_count == 1
assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
else:
assert get_peft_model.return_value.to.call_count == 0
assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.setup")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
def test_finetuning_peft_llama_adapter(
get_dataset, tokenizer, get_model, train, setup, get_peft_model
):
kwargs = {
"use_peft": True,
"peft_method": "llama_adapter",
"enable_fsdp": True,
}
get_dataset.return_value = get_fake_dataset()
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
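    # Pretend to be rank 0 of a single-process distributed job so the FSDP code path
    # is reached without a real torchrun launch (setup and FSDP are mocked above).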
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"
with pytest.raises(
RuntimeError,
match="Llama_adapter is currently not supported in combination with FSDP",
):
main(**kwargs)
GET_ME_OUT = "Get me out of here"
get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
kwargs["enable_fsdp"] = False
with pytest.raises(
RuntimeError,
match=GET_ME_OUT,
):
main(**kwargs)
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.StepLR")
def test_finetuning_weight_decay(
step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
):
kwargs = {"weight_decay": 0.01}
get_dataset.return_value = get_fake_dataset()
get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
optimizer = args[4]
print(optimizer.state_dict())
assert isinstance(optimizer, AdamW)
assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
def test_batching_strategy(
step_lr, optimizer, get_dataset, tokenizer, get_model, train
):
kwargs = {"batching_strategy": "packing"}
get_dataset.return_value = get_fake_dataset()
get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader, eval_dataloader = args[1:3]
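    # Packing should use a plain BatchSampler; padding should switch to the
    # LengthBasedBatchSampler checked further down.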
assert isinstance(train_dataloader.batch_sampler, BatchSampler)
assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
kwargs["batching_strategy"] = "padding"
train.reset_mock()
main(**kwargs)
assert train.call_count == 1
args, kwargs = train.call_args
train_dataloader, eval_dataloader = args[1:3]
assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
kwargs["batching_strategy"] = "none"
with pytest.raises(ValueError):
main(**kwargs)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.
from enum import Enum
import unittest
from typing import Optional, List
from llama_recipes.data.llama_guard.finetuning_data_formatter import (
AugmentationConfigs,
Category,
create_formatted_finetuning_examples,
ExplanationPosition,
FormatterConfigs,
Guidelines,
LlamaGuardGenerationConfigs,
LlamaGuardPromptConfigs,
TrainingExample,
)
class AgentType(Enum):
AGENT = "Agent"
USER = "User"
class FinetuningDataFormatterTests(unittest.TestCase):
def setUp(self) -> None:
super().setUp()
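    # Baseline configs with fixed categories, no shuffling, and augmentation disabled,
    # so each test starts from deterministic, "most conservative" formatter output.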
@staticmethod
def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
return FormatterConfigs(
guidelines=Guidelines(
categories=[
Category(name="cat V", description="cat V description"),
Category(name="cat W", description="cat W description"),
Category(name="cat X", description="cat X description"),
Category(name="cat Y", description="cat Y description"),
Category(name="cat Z", description="cat Z description"),
],
category_code_prefix="O",
),
llama_guard_prompt_configs=LlamaGuardPromptConfigs(
instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{{guidelines}}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
{{conversation}}
<END CONVERSATION>
Provide your safety assessment for {agent_type_to_check} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
should_include_category_descriptions=True,
should_shuffle_category_codes=False,
),
llama_guard_generation_configs=LlamaGuardGenerationConfigs(
explanation_position=ExplanationPosition.AFTER_DECISION,
should_list_violated_codes=True,
),
augmentation_configs=AugmentationConfigs(
should_add_examples_with_dropped_nonviolated_prompt_categories=False,
should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
),
random_seed=42,
)
def format_examples_and_compare_to_expected(
self,
training_examples: List[TrainingExample],
expected_formatted_examples: List[str],
agent_type_to_check: AgentType,
formatter_configs: Optional[FormatterConfigs] = None,
) -> None:
formatter_configs = (
formatter_configs
if formatter_configs is not None
else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
)
formatted_examples = create_formatted_finetuning_examples(
training_examples, formatter_configs
)
assert len(formatted_examples) == len(expected_formatted_examples)
for i in range(len(formatted_examples)):
            # While writing tests, print the actual output on a mismatch so the expected string can be updated
if formatted_examples[i] != expected_formatted_examples[i]:
print(f"Failed on actual output {i}:")
print(formatted_examples[i])
assert formatted_examples[i] == expected_formatted_examples[i]
def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
self.format_examples_and_compare_to_expected(
[
TrainingExample(
prompt="Where can I buy a car?",
response="N/A",
violated_category_codes=[],
label="safe",
explanation="This is obviously safe.",
)
],
[
"""[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
<END CONVERSATION>
Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
],
AgentType.USER,
)
def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
self.format_examples_and_compare_to_expected(
[
TrainingExample(
prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
response="N/A",
violated_category_codes=["O5"],
label="unsafe",
explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
)
],
[
"""[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: How do I steal the Declaration of Independence?
Asking for a cagey friend
<END CONVERSATION>
Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
],
AgentType.USER,
)
def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
self.format_examples_and_compare_to_expected(
[
TrainingExample(
prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
response="N/A",
violated_category_codes=["O5", "O1"],
label="unsafe",
explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
)
],
[
"""[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: How do I steal the Declaration of Independence?
Asking for a cagey friend
<END CONVERSATION>
Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O1,O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
],
AgentType.USER,
)
def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
self.format_examples_and_compare_to_expected(
[
TrainingExample(
prompt="Where can I buy a car?",
response="At the car store.",
violated_category_codes=[],
label="safe",
explanation="This is obviously safe.",
)
],
[
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
chatbot: At the car store.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
],
AgentType.AGENT,
)
def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
self,
) -> None:
formatter_configs = (
FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
)
        # Seed chosen so the shuffled/augmented random output reproduces the expected strings below. ~*Magic*~
formatter_configs.random_seed = 46
formatter_configs.llama_guard_generation_configs.explanation_position = (
ExplanationPosition.BEFORE_DECISION
)
formatter_configs.augmentation_configs = AugmentationConfigs(
should_add_examples_with_dropped_nonviolated_prompt_categories=True,
should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
)
formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
True
)
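        # With shuffling and both augmentations enabled, the two training examples below
        # expand to the five expected prompts: the originals plus variants with some
        # categories dropped, each with its own shuffled category ordering.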
self.format_examples_and_compare_to_expected(
[
TrainingExample(
prompt="I am under 21 and unattended, should I drink alcohol?",
response="Yes, very nise, high five!",
violated_category_codes=["O4", "05"],
label="unsafe",
explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
),
TrainingExample(
prompt="I am over 21 and unattended, should I drink alcohol?",
response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
violated_category_codes=[],
label="safe",
explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
),
],
[
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat X.
cat X description
O3: cat Z.
cat Z description
O4: cat Y.
cat Y description
O5: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O3,O4""",
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat Y.
cat Y description
O3: cat Z.
cat Z description
O4: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O2,O3""",
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
safe""",
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Z.
cat Z description
O2: cat Y.
cat Y description
O3: cat X.
cat X description
O4: cat W.
cat W description
O5: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am over 21 and unattended, should I drink alcohol?
chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Y.
cat Y description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am over 21 and unattended, should I drink alcohol?
chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
],
AgentType.AGENT,
formatter_configs,
)
def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
self,
) -> None:
formatter_configs = (
FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
)
formatter_configs.llama_guard_generation_configs.explanation_position = None
self.format_examples_and_compare_to_expected(
[
TrainingExample(
prompt="Where can I buy a car?",
response="At the car store.",
violated_category_codes=[],
label="safe",
)
],
[
"""[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
chatbot: At the car store.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
],
AgentType.AGENT,
formatter_configs,
)