Commit b59a5620 authored by litzh's avatar litzh
Browse files

Initial commit

parents
Pipeline #3383 canceled with stages
# Copyright 2025 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by the Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING
import fire
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
if TYPE_CHECKING:
from transformers import PretrainedConfig
def change_name(name: str, old_index: int, new_index: int) -> str:
return name.replace(f".{old_index:d}.", f".{new_index:d}.")
def block_expansion(
model_name_or_path: str,
output_dir: str,
num_expand: int,
shard_size: str = "5GB",
save_safetensors: bool = True,
):
r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
num_layers = getattr(config, "num_hidden_layers")
if num_layers % num_expand != 0:
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
setattr(config, "num_hidden_layers", num_layers + num_expand)
config.save_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer.save_pretrained(output_dir)
print(f"Expanding model of {num_layers} layers to {num_layers + num_expand} layers.")
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True
)
assert isinstance(model, PreTrainedModel) # type hint
if save_safetensors and getattr(model.config, "tie_word_embeddings", False):
del model.lm_head # safetensors does not allow shared weights
split = num_layers // num_expand
layer_cnt = 0
state_dict = model.state_dict()
output_state_dict: dict[str, torch.Tensor] = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if f".{i:d}." in key:
output_state_dict[change_name(key, i, layer_cnt)] = value
print(f"Add layer {layer_cnt} copied from layer {i}.")
layer_cnt += 1
if (i + 1) % split == 0:
for key, value in state_dict.items():
if f".{i:d}." in key:
if "down_proj" in key or "o_proj" in key:
output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
else:
output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
print(f"Add layer {layer_cnt} expanded from layer {i}.")
layer_cnt += 1
for key, value in state_dict.items():
if key not in output_state_dict:
output_state_dict[key] = value
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(
output_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
)
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
shard = {tensor: output_state_dict[tensor].contiguous() for tensor in tensors}
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if not state_dict_split.is_sharded:
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
else:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print(f"Model weights saved in {output_dir}.")
print("- Fine-tune this model with:")
print(f"model_name_or_path: {output_dir}")
print("finetuning_type: freeze")
print(f"freeze_trainable_layers: {num_expand}")
print("use_llama_pro: true")
if __name__ == "__main__":
fire.Fire(block_expansion)
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import fire
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedModel
def quantize_loftq(
model_name_or_path: str,
output_dir: str,
loftq_bits: int = 4,
loftq_iter: int = 4,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
lora_target = [name.strip() for name in lora_target.split(",")]
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=True,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=lora_target,
init_lora_weights="loftq",
loftq_config=loftq_config,
)
# Init LoftQ model
print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
peft_model = get_peft_model(model, lora_config)
loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print(f"Adapter weights saved in {loftq_dir}")
# Save base model
base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:")
print(f"model_name_or_path: {output_dir}")
print(f"adapter_name_or_path: {loftq_dir}")
print("finetuning_type: lora")
print(f"quantization_bit: {loftq_bits}")
if __name__ == "__main__":
fire.Fire(quantize_loftq)
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import fire
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedModel
def quantize_pissa(
model_name_or_path: str,
output_dir: str,
pissa_iter: int = 16,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
lora_target = [name.strip() for name in lora_target.split(",")]
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=lora_target,
init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
)
# Init PiSSA model
peft_model = get_peft_model(model, lora_config)
pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print(f"Adapter weights saved in {pissa_dir}")
# Save base model
base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")
print("- Fine-tune this model with:")
print(f"model_name_or_path: {output_dir}")
print(f"adapter_name_or_path: {pissa_dir}")
print("finetuning_type: lora")
print("pissa_init: false")
print("pissa_convert: true")
print("- and optionally with:")
print("quantization_bit: 4")
if __name__ == "__main__":
fire.Fire(quantize_pissa)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Why we need this script for qwen_omni?
Because the qwen_omni model is constructed by two parts:
1. [Thinker]:[audio_encoder, vision_encoder, LLM backbone], which our repository does support to post-training.
2. [Talker]: [audio_decoder, wave_model], which is not supported to post-training without specific tokenizer.
When we post-training the model, we exactly train the [Thinker] part, and the [Talker] part is dropped.
So, to get the complete model, we need to merge the [Talker] part back to the [Thinker] part.
LoRA mode: [Thinker + LoRA weights] + [Original Talker] -> [Omni model]
Full mode: [Thinker] + [Original Talker] -> [Omni model]
For Processor, we do saved the processor from trained model instead of the original model.
"""
import os
import shutil
import fire
from peft import PeftModel
from transformers import (
AutoProcessor,
Qwen2_5OmniForConditionalGeneration, # type: ignore
Qwen2_5OmniThinkerForConditionalGeneration,
)
def merge_lora(
base_model_path: str,
lora_checkpoint_path: str,
extra_file: str = "spk_dict.pt",
submodule_name: str = "thinker",
save_path: str = "./merged_model_checkpoint",
):
"""Load the original model, merge the LoRA weights.
For a specified submodule, and save the final merged model along with its configurations.
Args:
base_model_path (str): Path to the original model directory.
lora_checkpoint_path (str): Path to the directory containing LoRA weights.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
submodule_name (str): Name of the submodule to merge (default: "thinker").
save_path (str): Directory where the merged model and configurations will be saved.
"""
# 1. Load the original model
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(base_model_path, torch_dtype="auto", device_map="cpu")
print("Successfully loaded the original model.")
# 2. Extract the submodule to be merged (e.g., model.thinker)
if not hasattr(model, submodule_name):
raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")
base_submodule = getattr(model, submodule_name)
print(f"Successfully extracted submodule: {submodule_name}.")
# 3. Load the LoRA weights onto the extracted submodule
lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
processor = AutoProcessor.from_pretrained(lora_checkpoint_path)
print("LoRA weights and processor loaded successfully.")
# 4. Merge the LoRA weights into the submodule and unload the LoRA modules
merged_submodule = lora_model.merge_and_unload()
print("LoRA weights merged successfully.")
# 5. Replace the original submodule with the merged submodule in the model
setattr(model, submodule_name, merged_submodule)
# 6. Save the final merged model along with the tokenizer and processor configuration
model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Merged model and tokenizer saved to {save_path}.")
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
shutil.copy(source_file, target_file)
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
else:
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
def save_full_model(
saved_thinker_path: str,
base_model_path: str,
save_path: str = "./merged_model_checkpoint",
extra_file: str = "spk_dict.pt",
):
"""Load the saved thinker module and the original model, replace the thinker in the original model.
Then save the complete model along with its tokenizer and processor configuration.
Args:
saved_thinker_path (str): Path to the saved thinker weights.
base_model_path (str): Directory path of the original model.
save_path (str): Directory where the merged model and configurations will be saved.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
"""
# 1. Load the saved thinker module and the original model
thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
saved_thinker_path, torch_dtype="auto", device_map="cpu"
)
base_model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
base_model_path, torch_dtype="auto", device_map="cpu"
)
base_model.thinker = thinker
# 2. Save the complete model along with its tokenizer and processor configuration
processor = AutoProcessor.from_pretrained(saved_thinker_path)
base_model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Merged model and processor saved to {save_path}.")
# 3. Copy the extra file from the base model directory to the save_path
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
shutil.copy(source_file, target_file)
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
else:
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
if __name__ == "__main__":
fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
# Copyright 2025 Microsoft Corporation and the LlamaFactory team.
#
# This code is inspired by the Microsoft's DeepSpeed library.
# https://www.deepspeed.ai/tutorials/flops-profiler/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fire
import torch
from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llamafactory.chat import ChatModel
def calculate_flops(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 512,
flash_attn: str = "auto",
):
r"""Calculate the flops of pre-trained models.
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
"""
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
)
print("FLOPs:", flops)
print("MACs:", macs)
print("Params:", params)
if __name__ == "__main__":
fire.Fire(calculate_flops)
# Copyright 2025 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Literal
import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048, # i.e. maximum input length during training
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage=stage,
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
packing=packing,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
)
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
)
else:
raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader, desc="Collecting valid tokens"):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
valid_ratio = valid_tokens / total_tokens
token_batch_size = cutoff_len * batch_size * valid_ratio
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
f"and effective token batch size {token_batch_size:.2f}"
)
if __name__ == "__main__":
fire.Fire(calculate_lr)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig
from llamafactory.train.tuner import run_exp
BASE = 2 # gemm (add + mul)
def compute_model_flops(
model_name_or_path: str,
total_batch_size: int,
seq_length: int,
include_backward: bool = True,
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
r"""Calculate the FLOPs of model per forward/backward pass."""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
intermediate_size = getattr(config, "intermediate_size", None)
num_attention_heads = getattr(config, "num_attention_heads", None)
num_key_value_heads = getattr(config, "num_key_value_heads", None)
num_hidden_layers = getattr(config, "num_hidden_layers", None)
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
# mlp module
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size # up, gate, down
mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
# attn projector module
q_flops_per_token = BASE * hidden_size * hidden_size
o_flops_per_token = BASE * hidden_size * hidden_size
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
# attn sdpa module
sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length # (q * k^T) * v
sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
# embedding module
embedding_flops_per_token = hidden_size * vocab_size
embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
if tie_word_embeddings is False:
embedding_flops *= 2
non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
non_embedding_coeff, embedding_coeff = 1, 1
if include_backward:
non_embedding_coeff += 2
embedding_coeff += 2
if include_recompute:
non_embedding_coeff += 1
total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
if include_flashattn:
total_flops += sdpa_flops
return total_flops
def compute_device_flops(world_size: int) -> float:
r"""Calculate the FLOPs of the device capability per second."""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
elif "A100" in device_name or "A800" in device_name:
return 312 * 1e12 * world_size
elif "V100" in device_name:
return 125 * 1e12 * world_size
elif "4090" in device_name:
return 98 * 1e12 * world_size
else:
raise NotImplementedError(f"Device not supported: {device_name}.")
def calculate_mfu(
model_name_or_path: str,
batch_size: int = 1,
seq_length: int = 1024,
num_steps: int = 100,
finetuning_type: str = "lora",
flash_attn: str = "auto",
deepspeed_stage: int = 0,
disable_gc: bool = False,
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""Calculate MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
"model_name_or_path": model_name_or_path,
"flash_attn": flash_attn,
"disable_gradient_checkpointing": disable_gc,
"enable_liger_kernel": liger_kernel,
"use_unsloth_gc": unsloth_gc,
"stage": "pt",
"do_train": True,
"finetuning_type": finetuning_type,
"dataset": "c4_demo",
"cutoff_len": seq_length,
"output_dir": os.path.join("saves", "test_mfu"),
"logging_strategy": "no",
"save_strategy": "no",
"save_only_model": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": num_steps,
"bf16": True,
}
if deepspeed_stage in [2, 3]:
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args)
if dist.is_initialized():
dist.barrier()
world_size = dist.get_world_size()
else:
world_size = 1
if int(os.getenv("LOCAL_RANK", "0")) == 0:
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__":
fire.Fire(calculate_mfu)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from typing import Any, Literal, Optional
import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""Data collator for pairwise data."""
train_on_prompt: bool = False
def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
r"""Pad batched data to the longest sequence in the batch."""
chosen_features = []
for feature in features:
chosen_features.append(
{
"input_ids": feature["chosen_input_ids"],
"attention_mask": feature["chosen_attention_mask"],
"labels": feature["chosen_input_ids"] if self.train_on_prompt else feature["chosen_labels"],
"images": feature["images"],
"videos": feature["videos"],
"audios": feature["audios"],
}
)
return super().__call__(chosen_features)
def calculate_ppl(
model_name_or_path: str,
save_name: str = "ppl.json",
batch_size: int = 4,
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
r"""Calculate the ppl on the dataset of the pre-trained models.
Usage: export CUDA_VISIBLE_DEVICES=0
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
stage=stage,
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
max_samples=max_samples,
train_on_prompt=train_on_prompt,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
)
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
)
elif stage == "rm":
data_collator = PairwiseDataCollatorWithPadding(
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
)
else:
raise NotImplementedError(f"Stage does not supported: {stage}.")
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: dict[str, torch.Tensor]
with torch.no_grad():
for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
shift_labels: torch.Tensor = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()
perplexities.extend(sentence_logps.exp().tolist())
with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2)
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
print(f"Perplexities have been saved at {save_name}.")
if __name__ == "__main__":
fire.Fire(calculate_ppl)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import fire
from tqdm import tqdm
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
interval: int = 1000,
):
r"""Calculate the distribution of the input lengths in the dataset.
Usage: export CUDA_VISIBLE_DEVICES=0
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""
model_args, data_args, training_args, _, _ = get_train_args(
dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=1_000_000,
preprocessing_num_workers=16,
output_dir="dummy_dir",
overwrite_cache=True,
do_train=True,
)
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
length_dict[len(sample) // interval * interval] += 1
length_tuples = list(length_dict.items())
length_tuples.sort()
count_accu, prob_accu = 0, 0
for length, count in length_tuples:
count_accu += count
prob_accu += count / total_num * 100
print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
if __name__ == "__main__":
fire.Fire(length_cdf)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
from typing import Optional
import fire
from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
def vllm_infer(
model_name_or_path: str,
adapter_name_or_path: str = None,
dataset: str = "alpaca_en_demo",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: Optional[int] = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,
top_p: float = 0.7,
top_k: int = 50,
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
default_system: Optional[str] = None,
enable_thinking: bool = True,
seed: Optional[int] = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
video_fps: float = 2.0,
video_maxlen: int = 128,
batch_size: int = 1024,
):
r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
if pipeline_parallel_size > get_device_count():
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
model_args, data_args, _, generating_args = get_infer_args(
dict(
model_name_or_path=model_name_or_path,
adapter_name_or_path=adapter_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
max_samples=max_samples,
preprocessing_num_workers=16,
default_system=default_system,
enable_thinking=enable_thinking,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
)
)
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"dtype": model_args.infer_dtype,
"max_model_len": cutoff_len + max_new_tokens,
"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
"pipeline_parallel_size": pipeline_parallel_size,
"disable_log_stats": True,
"enable_lora": model_args.adapter_name_or_path is not None,
}
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
llm = LLM(**engine_args)
# load datasets
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
train_dataset = dataset_module["train_dataset"]
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
temperature=generating_args.temperature,
top_p=generating_args.top_p or 1.0, # top_p must > 0
top_k=generating_args.top_k or -1, # top_k must > 0
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
max_tokens=generating_args.max_new_tokens,
skip_special_tokens=skip_special_tokens,
seed=seed,
)
if model_args.adapter_name_or_path is not None:
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
lora_request = None
# Store all results in these lists
all_prompts, all_preds, all_labels = [], [], []
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], []
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
for j in range(len(batch["input_ids"])):
if batch["images"][j] is not None:
image = batch["images"][j]
multi_modal_data = {
"image": template_obj.mm_plugin._regularize_images(
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
)["images"]
}
elif batch["videos"][j] is not None:
video = batch["videos"][j]
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
video,
image_max_pixels=image_max_pixels,
image_min_pixels=image_min_pixels,
video_fps=video_fps,
video_maxlen=video_maxlen,
)["videos"]
}
elif batch["audios"][j] is not None:
audio = batch["audios"][j]
audio_data = template_obj.mm_plugin._regularize_audios(
audio,
sampling_rate=16000,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(
list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
skip_special_tokens=skip_special_tokens,
)
)
results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
preds = [result.outputs[0].text for result in results]
# Accumulate results
all_prompts.extend(prompts)
all_preds.extend(preds)
all_labels.extend(labels)
gc.collect()
# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(all_prompts, all_preds, all_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
print("*" * 70)
print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
print("*" * 70)
if __name__ == "__main__":
fire.Fire(vllm_infer)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
"torch-npu": ["torch-npu==2.5.1", "torchvision==0.20.1", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.9.1"],
"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"dev": ["pre-commit", "ruff", "pytest", "build"],
}
def main():
setup(
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga@buaa.edu.cn",
description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
if __name__ == "__main__":
main()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uvicorn
from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel
def main():
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)
if __name__ == "__main__":
main()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Efficient fine-tuning of large language models.
Level:
api, webui > chat, eval, train > data, model > hparams > extras
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
Use openmind: USE_OPENMIND_HUB=1
"""
from .extras.env import VERSION
__version__ = VERSION
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Annotated, Optional
from ..chat import ChatModel
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
create_chat_completion_response,
create_score_evaluation_response,
create_stream_chat_completion_response,
)
from .protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ModelCard,
ModelList,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
async def sweeper() -> None:
while True:
torch_gc()
await asyncio.sleep(300)
@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine.name == EngineName.HF:
asyncio.create_task(sweeper())
yield
torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.getenv("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
@app.get(
"/v1/models",
response_model=ModelList,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(
"/v1/chat/completions",
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if request.stream:
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
else:
return await create_chat_completion_response(request, chat_model)
@app.post(
"/v1/score/evaluation",
response_model=ScoreEvaluationResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
return await create_score_evaluation_response(request, chat_model)
return app
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.getenv("API_HOST", "0.0.0.0")
api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")
uvicorn.run(app, host=api_host, port=api_port)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import json
import os
import re
import uuid
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Optional
from ..data import Role as DataRole
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
Function,
FunctionCall,
Role,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import HTTPException, status
if is_pillow_available():
from PIL import Image
if is_requests_available():
import requests
if TYPE_CHECKING:
from ..chat import ChatModel
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = logging.get_logger(__name__)
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
def _process_request(
request: "ChatCompletionRequest",
) -> tuple[
list[dict[str, str]],
Optional[str],
Optional[str],
Optional[list["ImageInput"]],
Optional[list["VideoInput"]],
Optional[list["AudioInput"]],
]:
if is_env_enabled("API_VERBOSE", "1"):
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
content = request.messages.pop(0).content
system = content[0].text if isinstance(content, list) else content
else:
system = None
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
images, videos, audios = [], [], []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
tool_calls = [
{"name": tool_call.function.name, "arguments": tool_call.function.arguments}
for tool_call in message.tool_calls
]
content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
text_content = ""
for input_item in message.content:
if input_item.type == "text":
text_content += input_item.text
elif input_item.type == "image_url":
text_content += IMAGE_PLACEHOLDER
image_url = input_item.image_url.url
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url): # local file
image_stream = open(image_url, "rb")
else: # web uri
image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB"))
elif input_item.type == "video_url":
text_content += VIDEO_PLACEHOLDER
video_url = input_item.video_url.url
if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(video_url): # local file
video_stream = video_url
else: # web uri
video_stream = requests.get(video_url, stream=True).raw
videos.append(video_stream)
elif input_item.type == "audio_url":
text_content += AUDIO_PLACEHOLDER
audio_url = input_item.audio_url.url
if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(audio_url): # local file
audio_stream = audio_url
else: # web uri
audio_stream = requests.get(audio_url, stream=True).raw
audios.append(audio_stream)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
)
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except json.JSONDecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
return input_messages, system, tools, images or None, videos or None, audios or None
def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images, videos, audios = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
images,
videos,
audios,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
repetition_penalty=request.presence_penalty,
stop=request.stop,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.extract_tool(response.response_text)
else:
result = response.response_text
if isinstance(result, list):
tool_calls = []
for tool in result:
function = Function(name=tool.name, arguments=tool.arguments)
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images, videos, audios = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
if request.n > 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
)
async for new_token in chat_model.astream_chat(
input_messages,
system,
tools,
images,
videos,
audios,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
repetition_penalty=request.presence_penalty,
stop=request.stop,
):
if len(new_token) != 0:
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
)
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
score_id = f"scoreval-{uuid.uuid4().hex}"
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from enum import Enum, unique
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
TOOL = "tool"
@unique
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
TOOL = "tool_calls"
class ModelCard(BaseModel):
id: str
object: Literal["model"] = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: list[ModelCard] = []
class Function(BaseModel):
name: str
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: str
type: Literal["function"] = "function"
function: Function
class URL(BaseModel):
url: str
detail: Literal["auto", "low", "high"] = "auto"
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url", "video_url", "audio_url"]
text: Optional[str] = None
image_url: Optional[URL] = None
video_url: Optional[URL] = None
audio_url: Optional[URL] = None
class ChatMessage(BaseModel):
role: Role
content: Optional[Union[str, list[MultimodalInputItem]]] = None
tool_calls: Optional[list[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None
content: Optional[str] = None
tool_calls: Optional[list[FunctionCall]] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: list[ChatMessage]
tools: Optional[list[FunctionAvailable]] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
stream: bool = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatCompletionMessage
finish_reason: Finish
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
class ChatCompletionResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage
class ChatCompletionStreamResponse(BaseModel):
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
model: str
messages: list[str]
max_length: Optional[int] = None
class ScoreEvaluationResponse(BaseModel):
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: list[float]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_engine import BaseEngine
from .chat_model import ChatModel
__all__ = ["BaseEngine", "ChatModel"]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..extras.constants import EngineName
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class BaseEngine(ABC):
r"""Base class for inference engine of chat models.
Must implements async methods: chat(), stream_chat() and get_scores().
"""
name: "EngineName"
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
template: "Template"
generating_args: dict[str, Any]
@abstractmethod
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
r"""Initialize an inference engine."""
...
@abstractmethod
async def chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
...
@abstractmethod
async def stream_chat(
self,
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""Get the response token-by-token of the chat model."""
...
@abstractmethod
async def get_scores(
self,
batch_input: list[str],
**input_kwargs,
) -> list[float]:
r"""Get a list of scores of the reward model."""
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment