#!/bin/bash
torchrun --nnodes=1 --nproc-per-node=4 training_multimodal.py \
--model_name "Qwen/Qwen2-VL-7B-Instruct" \
--bf16 \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--eval_strategy "no" \
--save_strategy "no" \
--learning_rate 6e-6 \
--weight_decay 0.05 \
--warmup_ratio 0.1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--include_num_input_tokens_seen \
--report_to none \
--fsdp "full_shard auto_wrap" \
--fsdp_config config/fsdp_config.json \
--seed 42 \
--use_liger True \
--output_dir multimodal_finetuning
from dataclasses import dataclass
import datasets
import torch
import transformers
from callback import EfficiencyCallback
from trl import DataCollatorForCompletionOnlyLM
from trl import SFTTrainer
from liger_kernel.transformers import AutoLigerKernelForCausalLM
@dataclass
class CustomArguments:
model_name: str = "meta-llama/Meta-Llama-3-8B"
dataset: str = "tatsu-lab/alpaca"
max_seq_length: int = 512
use_liger: bool = False
def formatting_prompts_func(example):
return example["text"]
def train():
parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
training_args, custom_args = parser.parse_args_into_dataclasses()
tokenizer = transformers.AutoTokenizer.from_pretrained(
custom_args.model_name,
padding_side="left",
truncation_side="left",
)
tokenizer.pad_token = tokenizer.eos_token
dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split(test_size=0.1)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False)
collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
response_template=response_prompt,
pad_to_multiple_of=16,
)
if custom_args.use_liger:
model = AutoLigerKernelForCausalLM.from_pretrained(
custom_args.model_name,
trust_remote_code=True,
use_cache=False,
dtype=torch.bfloat16,
# These args will get passed to the appropriate apply_liger_kernel_to_* function
# to override the default settings
# cross_entropy=True,
# fused_linear_cross_entropy=False,
)
else:
model = transformers.AutoModelForCausalLM.from_pretrained(
custom_args.model_name,
trust_remote_code=True,
use_cache=False,
dtype=torch.bfloat16,
)
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collator,
max_seq_length=custom_args.max_seq_length,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
formatting_func=formatting_prompts_func,
callbacks=[EfficiencyCallback()],
)
trainer.train()
if __name__ == "__main__":
train()
import os
from dataclasses import dataclass
import datasets
import torch
import transformers
from callback import EfficiencyCallback
from datasets import Image as ImageFeature
from trl import SFTTrainer
from liger_kernel.transformers import monkey_patch
@dataclass
class CustomArguments:
model_name: str = "Qwen/Qwen2-VL-2B-Instruct"
dataset: str = "HuggingFaceM4/the_cauldron"
dataset_subset: str = "ai2d"
dataset_split: str = "train"
max_seq_length: int = 512
dataset_text_field: str = "texts"
use_liger: bool = False
def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.Module:
if "Qwen2-VL" in model_name:
from transformers import Qwen2VLForConditionalGeneration
# These settings are used to reduce the memory footprint of the Qwen2-VL model,
# which supports training/inferences on images in their native resolution. Large
# images -> many visual tokens (a max of 16384) -> large memory consumption.
# If fine-tuning for a real-world application, consider these values carefully.
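# Note: 256 visual tokens * (28 * 28) pixels per token = 200,704 pixels, i.e. roughly a 448x448 image budget.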
min_visual_tokens_per_image = 256
max_visual_tokens_per_image = 256
processor = transformers.AutoProcessor.from_pretrained(
model_name,
padding_side="left",
truncation_side="left",
min_pixels=min_visual_tokens_per_image * 28 * 28, # patch size is 14x14
max_pixels=max_visual_tokens_per_image * 28 * 28, # 4 patches / token
)
processor.tokenizer.pad_token = processor.tokenizer.eos_token
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
if use_liger:
print("Applying Liger Kernel to Qwen2-VL model")
monkey_patch.apply_liger_kernel_to_qwen2_vl(
# These args can be used to override the default Liger settings
# cross_entropy=True,
# fused_linear_cross_entropy=False,
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
pretrained_model_name_or_path=model_name,
use_cache=False,
dtype=torch.bfloat16,
low_cpu_mem_usage=True,
attn_implementation="sdpa",
)
return model, processor, image_token_id
raise NotImplementedError(f"Model {model_name} not supported")
def _validate_and_extract_the_cauldron(examples) -> dict[str, list]:
batch_texts = []
batch_images = []
for images, texts in zip(examples["images"], examples["texts"]):
if not images:
raise ValueError("No image found in example from the_cauldron dataset")
if len(images) > 1:
raise ValueError("Only one image per example is supported")
batch_texts.extend(texts)
batch_images.extend([images[0]] * len(texts))
return {"texts": batch_texts, "images": batch_images}
def _format_for_convo(example, tokenizer):
# cauldron data is already in message format {"user": ..., "assistant": ...}
text = example["texts"]
messages = [
{
"role": "user",
"content": [{"type": "image"}, {"type": "text", "text": text["user"]}],
},
{"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
return {"texts": text}
def train():
parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
training_args, custom_args = parser.parse_args_into_dataclasses()
training_args.remove_unused_columns = False # required to not drop the image column
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger)
dataset = (
datasets.load_dataset(
custom_args.dataset,
custom_args.dataset_subset,
split=custom_args.dataset_split,
)
.map(
_validate_and_extract_the_cauldron,
batched=True,
num_proc=min(os.cpu_count(), 16),
desc="Extracting text and images",
)
.map(
_format_for_convo,
fn_kwargs={"tokenizer": processor.tokenizer},
desc="Formatting for convo",
)
.cast_column("images", ImageFeature())
.train_test_split(test_size=0.1)
)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
def collate_fn(examples):
"""
Taken directly from the TRL documentation with minor modifications:
https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data
Modifications:
1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
3. Ignoring image tokens in the loss computation
"""
# Get the texts and images
texts = [example["texts"] for example in examples]
images = [example["images"] for example in examples]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
# Ignore the image token index in the loss computation
labels[labels == image_token_id] = -100
batch["labels"] = labels
return batch
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
max_seq_length=custom_args.max_seq_length,
dataset_text_field=custom_args.dataset_text_field,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=processor.tokenizer,
callbacks=[EfficiencyCallback()],
)
trainer.train()
if __name__ == "__main__":
train()
# Liger-Kernel Example with Lightning Trainer
## How to Run
```bash
pip install -r requirements.txt
# For single L40 48GB GPU
python training.py --model Qwen/Qwen2-0.5B-Instruct --num_gpu 1 --max_length 1024
# For 8XA100 40GB
python training.py --model meta-llama/Meta-Llama-3-8B --strategy deepspeed
```
**Notes**
1. The example uses the Llama3 model, which requires agreeing to the community license and logging in to the HuggingFace Hub. If you want to use Llama3 in this example, please make sure you have done the following:
* Agree to the community license agreement at https://huggingface.co/meta-llama/Meta-Llama-3-8B
* Run `huggingface-cli login` and enter your HuggingFace token
2. The default hyperparameters and configuration for Gemma work on a single L40 48GB GPU, and the Llama config works on a single node with 8xA100 40GB GPUs. When running on devices with less GPU RAM, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP (see the sketch below).
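A minimal sketch of enabling the CPU offload mentioned above with Lightning's `FSDPStrategy`; it assumes the same Llama decoder-layer auto-wrap policy that `training.py` uses, and the resulting `strategy` would replace the one constructed there:

```python
import torch
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# CPUOffload keeps sharded parameters on the host between uses,
# trading GPU memory for extra host<->device transfers.
strategy = FSDPStrategy(
    auto_wrap_policy={LlamaDecoderLayer},
    sharding_strategy="FULL_SHARD",
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
)
```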
<!-- Benchmark TBD -->
lightning
transformers
trl
liger-kernel
torch
triton
deepspeed
tf-keras
import argparse
import math
import os
from dataclasses import _MISSING_TYPE
from dataclasses import dataclass
import datasets
import lightning.pytorch as pl
import torch
import transformers
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import MixedPrecision
from torch.utils.data import DataLoader
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from trl import DataCollatorForCompletionOnlyLM
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.utils import infer_device
_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
QUESTION = "<Question>"
CHOICES = "<Choices>"
@dataclass
class Args:
model: str = "Qwen/Qwen2-0.5B-Instruct"
data: str = "cais/mmlu"
output_dir: str = "mmlu_finetuning"
max_length: int = 2048
# For the Llama 3 8B model, DeepSpeed will OOM with batch size 16 on 8xA100 80GB, and batch size 8 will OOM on 8xA100 40GB
batch_size: int = 4
lr: float = 6e-6
weight_decay: float = 0.05
warmup_ratio: float = 0.1
seed: int = 42
strategy: str = "auto"
num_gpu: int = None
def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
def lr_lambda(current_step):
if current_step < warmup_steps:
# Linear warmup
return float(current_step) / float(max(1, warmup_steps))
else:
# Cosine annealing
progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))
return lr_lambda
def parse_args() -> Args:
parser = argparse.ArgumentParser()
for k, v in Args.__dataclass_fields__.items():
parser.add_argument(f"--{k}", type=v.type, default=v.default)
parsed = parser.parse_args()
return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})
class LanguageModel(pl.LightningModule):
def __init__(self, args: Args, tokenizer):
super().__init__()
self.args = args
self.tokenizer = tokenizer
self.model = None
def configure_model(self):
# https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
if self.model is not None:
return
self.model = AutoLigerKernelForCausalLM.from_pretrained(
self.args.model, use_cache=False, ignore_mismatched_sizes=True
)
if self.args.strategy == "deepspeed":
self.model.train()
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, attention_mask, labels=None, **kwargs):
return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
def training_step(self, batch):
outputs = self.model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
)
loss = outputs.loss
self.log_dict(
{"train_loss": loss},
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
rank_zero_only=True,
sync_dist=False,
)
return loss
def validation_step(self, batch):
outputs = self.model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
)
loss = outputs.loss
self.log_dict(
{"val_loss": outputs.loss},
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
rank_zero_only=True,
sync_dist=True,
)
return loss
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.args.lr,
weight_decay=self.args.weight_decay,
fused=True,
)
lr_lambda = warmup_cosine_schedule(
warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio,
total_steps=self.trainer.estimated_stepping_batches,
min_lr=0,
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return {
"optimizer": optimizer,
"lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
}
class DataModule(pl.LightningDataModule):
def __init__(self, tokenizer, args: Args):
super().__init__()
self.args = args
self.tokenizer = tokenizer
self.response_template_str = " <Answer>"
response_prompt = tokenizer.encode(f"{self.response_template_str}", add_special_tokens=False)
self.collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
response_template=response_prompt,
pad_to_multiple_of=16,
)
def formatting_func(self, example):
output_texts = []
for i in range(len(example["question"])):
choices = ""
for j in range(len(example["choices"][i])):
choices += f"{j + 1}. {example['choices'][i][j]}; "
s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
s += f"{QUESTION}{example['question'][i]} "
s += f"{CHOICES}{choices} "
s += f"{self.response_template_str}{example['answer'][i]}"
output_texts.append(s)
return output_texts
def tokenize(self, example):
outputs = self.tokenizer(
self.formatting_func(example),
truncation=True,
padding=False,
max_length=self.args.max_length,
)
return {
"input_ids": outputs["input_ids"],
"attention_mask": outputs["attention_mask"],
}
def setup(self, stage) -> None:
dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
flattened_data = [
{
"answer": x["train"]["answer"],
"choices": x["train"]["choices"],
"question": x["train"]["question"],
"subject": x["train"]["subject"],
}
for x in dataset["train"]
]
dataset = datasets.Dataset.from_list(flattened_data)
dataset = dataset.train_test_split(test_size=4096, seed=self.args.seed)
train_dataset, val_dataset = dataset["train"], dataset["test"]
self.train_dataset = train_dataset.map(
self.tokenize,
remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),
batched=True,
batch_size=1,
num_proc=4,
)
self.val_dataset = val_dataset.map(
self.tokenize,
remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),
batched=True,
batch_size=1,
num_proc=4,
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.args.batch_size,
collate_fn=self.collator,
)
def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.args.batch_size,
collate_fn=self.collator,
)
def train():
args = parse_args()
pl.seed_everything(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
if "Meta-Llama-3-8B" in args.model:
layers = {LlamaDecoderLayer}
elif "Qwen2" in args.model:
layers = {Qwen2DecoderLayer}
else:
raise NotImplementedError(f"Unimplemented layer wrap policy for {args.model} in this example")
if args.strategy == "fsdp":
strategy = FSDPStrategy(
auto_wrap_policy=layers,
sharding_strategy="FULL_SHARD",
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
sync_module_states=True,
activation_checkpointing_policy=layers,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
forward_prefetch=True,
)
precision = None
elif args.strategy == "deepspeed":
strategy = DeepSpeedStrategy(stage=3)
precision = "bf16-mixed"
elif args.strategy == "ddp":
strategy = "ddp"
precision = "bf16-true"
else:
strategy = "auto"
precision = "bf16-true"
device = infer_device()
trainer = pl.Trainer(
accelerator=device,
strategy=strategy,
devices=(getattr(torch, device).device_count() if args.num_gpu is None else args.num_gpu),
default_root_dir=args.output_dir,
log_every_n_steps=1,
max_epochs=1,
precision=precision,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side="left", truncation_side="left")
tokenizer.pad_token = tokenizer.eos_token
data_module = DataModule(
tokenizer=tokenizer,
args=args,
)
model = LanguageModel(args=args, tokenizer=tokenizer)
trainer.fit(model, datamodule=data_module)
if __name__ == "__main__":
train()
# Liger-Kernel Example with Medusa
Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads. [[repo](https://github.com/FasterDecoding/Medusa)], [[paper](https://arxiv.org/abs/2401.10774)]
During training, Medusa requires adding $k$ decoding heads to the hidden states right before the regular LM head $h_t$. The $k$-th head is used to predict the token in the $(t + k + 1)$-th position of the next tokens (the original language model head is used to predict the $(t + 1)$-th position).
The Liger fused CE kernel is highly effective in this scenario, eliminating the need to materialize logits for each head, which usually consumes a large volume of memory due to the extensive vocabulary size (e.g., for LLaMA-3, the vocabulary size is 128k). The introduction of multiple heads can easily lead to OOM (Out of Memory) issues. However, thanks to the efficient Liger fused CE, which calculates the gradient in place and doesn't materialize the logits, we have observed very effective results. This efficiency opens up more opportunities for multi-token prediction research and development.
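As a rough, self-contained illustration (hypothetical tensor shapes, not part of the training script), the snippet below mirrors the per-head loss pattern used later in `medusa_util.py`: the $k$-th head's hidden states are paired with labels shifted by $k + 1$, and the fused linear cross entropy consumes the head's weight matrix and the hidden states directly, so the full (tokens x vocab) logit tensor is never materialized.

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Hypothetical sizes for illustration only: batch=2, seq=16, hidden=4096, vocab=128256 (LLaMA-3).
# Requires a GPU, since the Liger kernels are Triton-based.
hidden_states = torch.randn(2, 16, 4096, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, 128256, (2, 16), device="cuda")
head_weight = torch.randn(128256, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True)

k = 1  # the k-th extra head predicts the (t + k + 1)-th token
shift_hidden = hidden_states[:, : -(1 + k), :].reshape(-1, 4096)
shift_labels = labels[:, (1 + k):].reshape(-1)

# The kernel fuses the linear projection and the cross entropy, computing gradients in place
lce = LigerFusedLinearCrossEntropyLoss()
loss_k = lce(head_weight, shift_hidden, shift_labels)
loss_k.backward()  # head_weight.grad is populated without materializing the logits
```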
# Instructions to Run the Training Script
```
git clone git@github.com:linkedin/Liger-Kernel.git
cd {PATH_TO_Liger-Kernel}/Liger-Kernel/
pip install -e .
cd {PATH_TO_Liger-Kernel}/Liger-Kernel/examples/medusa
pip install -r requirements.txt
sh scripts/llama3_8b_medusa.sh
```
**Notes**
1. This example uses an optional `use_liger` flag. If set to `True`, the script monkey-patches the model so that the Medusa heads use the Liger fused linear cross entropy kernel.
2. The example uses the Llama3 model, which requires agreeing to the community license and logging in to the HuggingFace Hub. If you want to use Llama3 in this example, please make sure you have done the following:
* Agree to the community license agreement at https://huggingface.co/meta-llama/Meta-Llama-3-8B
* Run `huggingface-cli login` and enter your HuggingFace token
3. The default hyperparameters and configurations work on a single node with 8xA100 GPUs. When running on devices with less GPU RAM, consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP (see the note below).
4. We are using a smaller sample of ShareGPT data, primarily to benchmark performance. The example requires hyperparameter tuning and dataset selection to work effectively, as well as ensuring the dataset has the same distribution as the LLaMA pretraining data. Contributions that enhance the example code are welcome.
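The CPU offload mentioned in note 3 corresponds to the `fsdp_offload_params` option in the accelerate config used by the launch script; setting `fsdp_offload_params: true` in `fsdp/acc-fsdp.conf` (it is `false` by default) enables parameter offloading at the cost of extra host-device transfers.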
# Memory Profiling Result
> **Note:**
> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 6, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
## Stage1
Stage1 refers to Medusa-1, where the backbone model is frozen and only the weights of the Medusa heads are updated.
```
# Setting this flag to True in llama3_8b_medusa.sh enables Stage1
--medusa_only_heads True
```
### num_head = 3
![Memory](./docs/images/Memory_Stage1_num_head_3.png)
![Throughput](./docs/images/Throughput_Stage1_num_head_3.png)
### num_head = 5
![Memory](./docs/images/Memory_Stage1_num_head_5.png)
![Throughput](./docs/images/Throughput_Stage1_num_head_5.png)
## Stage2
Stage2 refers to Medusa-2, where all the model weights are updated, including the backbone model and the LM heads.
```
# Setting this flag to False in llama3_8b_medusa.sh enables Stage2
--medusa_only_heads False
```
### num_head = 3
![Memory](./docs/images/Memory_Stage2_num_head_3.png)
![Throughput](./docs/images/Throughput_Stage2_num_head_3.png)
### num_head = 5
![Memory](./docs/images/Memory_Stage2_num_head_5.png)
![Throughput](./docs/images/Throughput_Stage2_num_head_5.png)
import os
import time
from dataclasses import dataclass
import torch
import transformers
from accelerate.utils.constants import FSDP_SHARDING_STRATEGY
from transformers import TrainerControl
from transformers import TrainerState
from transformers import TrainingArguments
from liger_kernel.utils import infer_device
# https://simple.wikipedia.org/wiki/Byte
# For memory, we use binary system
M_BIN_UNIT = 2**20
# For metrics (tflops), we use decimal system
T_DEC_UNIT = 10**12
def round_to_n_decimal(x, n):
return round(x, n)
@dataclass
class Precision:
"""
Precision is a dataclass to store the number of decimal points for each metric.
"""
n_decimal_time: int
n_decimal_memory: int
n_decimal_TPS: int
n_decimal_MFU: int
@dataclass
class State:
"""
State is a dataclass to store the internal state of the efficiency callback.
"""
n_warmup_steps: int = 0
total_peak_memory_allocated: float = float("-inf")
total_peak_memory_reserved: float = float("-inf")
step_start_time: float = 0.0
elapsed_time: float = 0.0
elapsed_step: int = 0
step_start_tokens_seen: int = 0
elapsed_tokens_seen: int = 0
step_start_flos: float = 0.0
elapsed_flos: float = 0.0
global_start_step: int = 0
@dataclass
class Time:
"""
Time is a dataclass to store the time-related metrics.
"""
step: int = 0
step_time_sec: float = 0.0
avg_step_time_sec: float = 0.0
time_to_completion_sec: float = 0.0
estimated_total_time_sec: float = 0.0
@dataclass
class Memory:
"""
Memory is a dataclass to store the memory-related metrics.
"""
step_peak_memory_allocated_MB: float = 0.0
total_peak_memory_allocated_MB: float = 0.0
@dataclass
class TPS:
"""
TPS is a dataclass to store the tokens per second metrics.
"""
step_tokens_per_second: float = 0.0
avg_tokens_per_second: float = 0.0
@dataclass
class MFU:
"""
MFU is a dataclass to store the MFU metrics.
"""
step_MFU: float = 0.0
avg_MFU: float = 0.0
class EfficiencyCallback(transformers.TrainerCallback):
"""
EfficiencyCallback is a callback to track the efficiency of the training process.
The tracked stats include: step time, memory, throughput, and MFU.
It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments.
Args:
n_warmup_steps: number of warmup steps
The stats in the first n_warmup_steps will not be added into the aggregated stats
This is because the first few steps might take longer due to JIT compilation and other initialization overheads
n_decimal_time: number of decimal points for time
n_decimal_memory: number of decimal points for memory
n_decimal_TPS: number of decimal points for TPS
n_decimal_MFU: number of decimal points for MFU in percentage
"""
def __init__(
self,
n_warmup_steps=2,
n_decimal_time=2,
n_decimal_memory=2,
n_decimal_TPS=2,
n_decimal_MFU=4,
):
self.state = State(
n_warmup_steps,
)
self.precision = Precision(
n_decimal_time,
n_decimal_memory,
n_decimal_TPS,
n_decimal_MFU,
)
self.time = Time()
self.memory = Memory()
self.tps = TPS()
self.mfu = MFU()
self.device = infer_device()
def on_init_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""
Event called at the end of the initialization of the [`Trainer`].
"""
if not args.include_num_input_tokens_seen:
raise Exception(
'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second'
)
if args.logging_steps != 1:
raise Exception("Please set logging_steps=1 to track the efficiency metrics accurately")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# if loaded from checkpoints, global_start_step is not 1 but state.global_step
self.state.global_start_step = state.global_step
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: dict[str, float],
**kwargs,
):
if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps):
return
else:
# spread self.time, self.memory, self.tps, self.mfu to logs
# logs.update(self.time.__dict__)
logs.update(self.memory.__dict__)
logs.update(self.tps.__dict__)
# logs.update(self.mfu.__dict__)
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
# memory
getattr(torch, self.device).reset_peak_memory_stats()
# time
self.state.step_start_time = time.perf_counter()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps):
# At the end of the current step, step_start_tokens_seen and step_start_flos become the starting values for the next iteration
# tokens
self.state.step_start_tokens_seen = state.num_input_tokens_seen
# flos
self.state.step_start_flos = state.total_flos
return
# time
current_time = time.perf_counter()
step_time = current_time - self.state.step_start_time
self.state.elapsed_time += step_time
# step
global_step = state.global_step
self.state.elapsed_step += 1
avg_step_time = self.state.elapsed_time / self.state.elapsed_step
self.time.step = global_step
self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time)
self.time.avg_step_time_sec = round_to_n_decimal(avg_step_time, self.precision.n_decimal_time)
self.time.time_to_completion_sec = round_to_n_decimal(
avg_step_time * (state.max_steps - global_step),
self.precision.n_decimal_time,
)
self.time.estimated_total_time_sec = round_to_n_decimal(
avg_step_time * state.max_steps, self.precision.n_decimal_time
)
# memory
step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated()
step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved()
self.memory.step_peak_memory_allocated_MB = round_to_n_decimal(
step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory
)
self.state.total_peak_memory_allocated = max(self.state.total_peak_memory_allocated, step_peak_memory_allocated)
self.memory.total_peak_memory_allocated_MB = round_to_n_decimal(
self.state.total_peak_memory_allocated / M_BIN_UNIT,
self.precision.n_decimal_memory,
)
self.memory.step_peak_memory_reserved_MB = round_to_n_decimal(
step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory
)
self.state.total_peak_memory_reserved = max(self.state.total_peak_memory_reserved, step_peak_memory_reserved)
self.memory.total_peak_memory_reserved_MB = round_to_n_decimal(
self.state.total_peak_memory_reserved / M_BIN_UNIT,
self.precision.n_decimal_memory,
)
# tokens
step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen
self.state.elapsed_tokens_seen += step_tokens_seen
self.tps.step_tokens_per_second = round_to_n_decimal(
step_tokens_seen / step_time,
self.precision.n_decimal_TPS,
)
self.tps.avg_tokens_per_second = round_to_n_decimal(
self.state.elapsed_tokens_seen / self.state.elapsed_time,
self.precision.n_decimal_TPS,
)
# flos
step_flos = state.total_flos - self.state.step_start_flos
self.state.elapsed_flos += step_flos
# MFU
# 1. Definition
#
# MFU is defined as (achieved TPS) / (theoretical maximum TPS) = (achieved floating point operations per sec) / (theoretical maximum floating point operations per sec)
# Crucially, the "theoretical maximum" throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization. MFU therefore allows fair comparisons
# between training runs on different systems, as the numerator is simply the observed tokens-per-second, and the denominator is only dependent on the model architecture and published maximum FLOPs for a given system.
# Ref: https://arxiv.org/pdf/2204.02311
#
# 2. Implementation in huggingface
#
# current_flos = 6 * estimate_tokens(input_dict) * num_parameters()
# total_flos = sum(current_flos) # across all GPUs
# Ref: https://github.com/huggingface/transformers/blob/616bb11d487aabc231bb230b245c42214ea4b254/src/transformers/modeling_utils.py#L1196
#
# 3. Derive MFU on rank 0
#
# rank_0_flos = total_flos / n_gpus = measured_flos / effective_n_gpus
# rank_0_MFU = rank_0_flos / step_time
#
# For FSDP, num_parameters() is (1 / n_gpus) of the total parameters. So, the effective_n_gpus = 1
# For HSDP, num_parameters() is (1 / local_world_size) of the total parameters. So, the effective_n_gpus = n_nodes
# For no sharding and zero-2, num_parameters() is the total parameters. So, the effective_n_gpus = n_gpus
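#
# Worked example with hypothetical numbers: step_flos = 3.0e14, step_time = 1.2 s, num_gpus = 1
#   step_achieved_tflops = 3.0e14 / 1.2 / 1 / 1e12 = 250
#   on an A100 with a bf16 peak of 312 TFLOPS, step_MFU = 250 / 312 ≈ 0.80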
num_gpus = EfficiencyCallback._get_effective_num_gpus()
step_achieved_tflops = step_flos / step_time / num_gpus / T_DEC_UNIT
avg_achieved_tflops = self.state.elapsed_flos / self.state.elapsed_time / num_gpus / T_DEC_UNIT
precision_bits = 16 if args.bf16 or args.fp16 else 32
gpu_peak_tflops = EfficiencyCallback._get_gpu_peak_tflops(precision_bits)
self.mfu.step_MFU = round_to_n_decimal(step_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU)
self.mfu.avg_MFU = round_to_n_decimal(avg_achieved_tflops / gpu_peak_tflops, self.precision.n_decimal_MFU)
# At the end of the current step, step_start_tokens_seen and step_start_flos become the starting values for the next iteration
# tokens
self.state.step_start_tokens_seen = state.num_input_tokens_seen
# flos
self.state.step_start_flos = state.total_flos
@staticmethod
def _get_effective_num_gpus():
# Calculate the number of effective GPUs for the total FLOPs in order to calculate the single GPU FLOP
world_size = int(os.environ.get("WORLD_SIZE", "1"))
if transformers.utils.strtobool(os.environ.get("ACCELERATE_USE_FSDP", "false")):
sharding_strategy = os.environ.get("FSDP_SHARDING_STRATEGY", FSDP_SHARDING_STRATEGY[0]).upper()
# Either specified as string or enum number
if sharding_strategy in {
"FULL_SHARD",
str(FSDP_SHARDING_STRATEGY.index("FULL_SHARD") + 1),
}:
return 1
elif sharding_strategy in {
"HYBRID_SHARD",
str(FSDP_SHARDING_STRATEGY.index("HYBRID_SHARD") + 1),
}:
return world_size // int(os.environ.get("LOCAL_WORLD_SIZE", 1))
else:
return world_size
assert world_size != 0, (
"WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
)
# TODO: add deepspeed support
return world_size
@staticmethod
def _get_gpu_peak_tflops(precision_bits: int = 16):
if precision_bits not in {16, 32}:
raise Exception(f"Precision bits {precision_bits} is not supported")
device_name = getattr(torch, infer_device()).get_device_name()
if "A100" in device_name:
# data from https://www.nvidia.com/en-us/data-center/a100/
return 312 if precision_bits == 16 else 156
elif "H100" in device_name:
# data from https://www.nvidia.com/en-us/data-center/h100/
# NOTE: Specifications are one-half lower without sparsity.
if "NVL" in device_name:
return 1979 if precision_bits == 16 else 989
elif "PCIe" in device_name:
return 756 if precision_bits == 16 else 378
else: # for SXM and other variants
return 989 if precision_bits == 16 else 494
elif "V100" in device_name:
if "NVL" in device_name:
return 125
else:
return 112
return None
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'yes'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: NO_PREFETCH
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
rdzv_backend: static
same_network: true
num_machines: 1
num_processes: 1
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
import types
from typing import List
from typing import Optional
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
class MedusaConfig(PretrainedConfig):
"""
Configuration class for Medusa model.
Args:
medusa_num_heads (int, optional): Number of Medusa decoding heads. Default is 4.
medusa_num_layers (int, optional): Number of ResBlock layers per Medusa head. Default is 1.
base_model_name_or_path (str, optional): The name or path of the base model. Default is "/shared/public/models/Meta-Llama-3-8B".
**kwargs: Additional keyword arguments to be passed to the parent class constructor.
"""
def __init__(
self,
medusa_num_heads=4,
medusa_num_layers=1,
base_model_name_or_path="/shared/public/models/Meta-Llama-3-8B",
**kwargs,
):
super().__init__(**kwargs)
self.medusa_num_heads = medusa_num_heads
self.medusa_num_layers = medusa_num_layers
self.base_model_name_or_path = base_model_name_or_path
class ResBlock(nn.Module):
"""
A Residual Block module.
This module performs a linear transformation followed by a SiLU activation,
and then adds the result to the original input, creating a residual connection.
Args:
hidden_size (int): The size of the hidden layers in the block.
"""
def __init__(self, hidden_size):
super().__init__()
self.linear = nn.Linear(hidden_size, hidden_size)
# Initialize as an identity mapping
nn.init.zeros_(self.linear.weight)
# Use SiLU activation to keep consistent with the Llama model
self.act = nn.SiLU()
def forward(self, x):
"""
Forward pass of the ResBlock.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output after the residual connection and activation.
"""
return x + self.act(self.linear(x))
def calculate_loss_contribution(
loss_i,
i,
medusa_only_heads,
medusa_decay_coefficient,
medusa_heads_coefficient,
medusa_scheduler_coefficient,
):
if i == 0:
return loss_i if not medusa_only_heads else 0
else:
return loss_i * medusa_decay_coefficient**i * medusa_heads_coefficient * medusa_scheduler_coefficient
def add_medusa_heads(
model,
medusa_num_heads=4,
medusa_num_layers=0,
medusa_return: bool = False,
medusa_only_heads: bool = False,
with_liger=True,
):
"""
Args:
model (nn.Module): The base language model to be used.
medusa_num_heads (int, optional): The number of additional tokens to predict. Defaults to 4.
medusa_num_layers (int, optional): The number of ResBlock layers for each Medusa head. Defaults to 0.
medusa_return (bool, optional): If True, returns the Medusa logits; otherwise, the forward pass will use the `lm_head`. Defaults to False.
medusa_only_heads (bool, optional): If True, only the Medusa head weights will be updated during fine-tuning; otherwise, the entire model's weights will be updated. Defaults to False.
with_liger (bool, optional): If True, applies Liger loss. Defaults to True.
"""
hidden_size = model.lm_head.weight.shape[-1]
vocab_size = model.lm_head.weight.shape[0]
model.config.medusa_num_layers = medusa_num_layers
model.config.medusa_num_heads = medusa_num_heads
model.medusa_num_heads = medusa_num_heads
# Create a list of Medusa heads
model.medusa_head = nn.ModuleList(
[
nn.Sequential(
*([ResBlock(hidden_size) for _ in range(medusa_num_layers)]),
nn.Linear(hidden_size, vocab_size, bias=False),
)
for _ in range(medusa_num_heads)
]
)
# Ensure medusa_head's dtype and device align with the base_model
model.medusa_head.to(model.dtype).to(model.device)
for i in range(medusa_num_heads):
# Initialize the weights of each medusa_head using the base model's weights
model.medusa_head[i][-1].weight.data[:] = model.lm_head.weight.data[:]
# logging the model summary
print(model)
model.old_forward = model.forward
def forward(
model,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""Forward pass of the MedusaModel.
Returns:
torch.Tensor: A tensor containing predictions from all Medusa heads.
(Optional) Original predictions from the base model's LM head.
"""
loss = 0
medusa_logits = None
# LOG.debug("medusa_return: %s", medusa_return)
if not medusa_return:
return model.old_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Pass input through the base model
if medusa_only_heads:
with torch.no_grad():
outputs = model.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
# The lm_head will be frozen as well, so it's within the context of torch.no_grad()
if not with_liger:
medusa_logits = [model.lm_head(hidden_states)]
else:
outputs = model.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if not with_liger:
medusa_logits = [model.lm_head(hidden_states)]
if not with_liger:
for i in range(model.medusa_num_heads):
medusa_logits.append(model.medusa_head[i](hidden_states))
medusa_logits = torch.stack(medusa_logits, dim=0)
if model.training:
# Fix all the coefficients to 1 for now
medusa_scheduler_coefficient = 1
medusa_heads_coefficient = 1
medusa_decay_coefficient = 1
loss = 0
if with_liger:
lce = LigerFusedLinearCrossEntropyLoss()
for i in range(model.medusa_num_heads + 1):
shift_hidden_states = (
hidden_states[..., : -(1 + i), :].contiguous().view(-1, model.config.hidden_size)
)
shift_labels = labels[..., (1 + i) :].contiguous().view(-1)
weight = model.lm_head.weight if i == 0 else model.medusa_head[i - 1][-1].weight
loss_i = lce(weight, shift_hidden_states, shift_labels)
loss += calculate_loss_contribution(
loss_i,
i,
medusa_only_heads,
medusa_decay_coefficient,
medusa_heads_coefficient,
medusa_scheduler_coefficient,
)
else:
loss_fct = CrossEntropyLoss()
for i in range(model.medusa_num_heads + 1):
medusa_logits_i = medusa_logits[i, :, : -(1 + i)].contiguous().view(-1, medusa_logits.shape[-1])
medusa_logits_i = medusa_logits_i.float()
medusa_labels = labels[..., (1 + i) :].contiguous().view(-1).to(medusa_logits_i.device)
loss_i = loss_fct(medusa_logits_i, medusa_labels)
loss += calculate_loss_contribution(
loss_i,
i,
medusa_only_heads,
medusa_decay_coefficient,
medusa_heads_coefficient,
medusa_scheduler_coefficient,
)
else:
if model.config.pretraining_tp > 1:
raise NotImplementedError
else:
medusa_logits = [model.lm_head(hidden_states)]
for i in range(model.medusa_num_heads):
medusa_logits.append(model.medusa_head[i](hidden_states))
return_dict = return_dict if return_dict is not None else model.config.use_return_dict
if not return_dict:
output = (medusa_logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=medusa_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
model.forward = types.MethodType(forward, model)
accelerate==1.6.0
scikit-learn
transformers==4.51.3
#!/bin/sh
export GPUS_PER_NODE=$(nvidia-smi --list-gpus | wc -l)
export LOCAL_WORLD_SIZE=$GPUS_PER_NODE
# WORLD_SIZE is expected to be provided by the launcher as the number of nodes; it is redefined below as the total number of workers
export NUM_NODES=$WORLD_SIZE
export WORLD_SIZE=$((GPUS_PER_NODE * NUM_NODES))
echo "Starting training... Num nodes: $NUM_NODES, Num workers: $WORLD_SIZE"
export OUTPUT_DIR="./llama3-8b-medusa-liger"
export LOCAL_TRAIN_BATCH_SIZE=4
export GRADIENT_ACCUMULATION_STEPS=1
export LR=1e-5
export MEDUSA_NUM_HEADS=5
export MEDUSA_NUM_LAYERS=1
export MEDUSA_HEADS_COEFFICIENT=0.2
export MEDUSA_DECAY_COEFFICIENT=0.8
export MEDUSA_SCHEDULER=constant
export MEDUSA_LR_MULTIPLIER=4.0
accelerate launch --config_file fsdp/acc-fsdp.conf \
--num_machines $NUM_NODES \
--num_processes $WORLD_SIZE \
train.py \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 10 \
--per_device_train_batch_size $LOCAL_TRAIN_BATCH_SIZE \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
--eval_strategy "no" \
--save_strategy "no" \
--prediction_loss_only \
--learning_rate $LR \
--weight_decay 0. \
--warmup_ratio 0.04 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 1024 \
--gradient_checkpointing True \
--lazy_preprocess False \
--report_to none \
--include_num_input_tokens_seen \
--medusa_num_heads $MEDUSA_NUM_HEADS \
--medusa_num_layers $MEDUSA_NUM_LAYERS \
--medusa_heads_coefficient $MEDUSA_HEADS_COEFFICIENT \
--medusa_decay_coefficient $MEDUSA_DECAY_COEFFICIENT \
--medusa_scheduler $MEDUSA_SCHEDULER \
--medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \
--medusa_only_heads False \
--medusa_return True \
--use_liger True