Commit 7ea81099 authored by chenych

update llama4

parent 84987715
......@@ -15,7 +15,7 @@
import json
import os
from collections import OrderedDict
from typing import Any, Dict
from typing import Any
import fire
import torch
......@@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
for key in f.keys():
qwen_state_dict[key] = f.get_tensor(key)
llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
torch_dtype = None
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
if torch_dtype is None:
......@@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f)
qwen_config_dict: dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict()
llama2_config_dict: dict[str, Any] = OrderedDict()
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict["hidden_act"] = "silu"
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
......@@ -147,8 +147,8 @@ def llamafy_qwen(
shard_size: str = "2GB",
save_safetensors: bool = False,
):
r"""
Converts the Qwen models in the same format as LLaMA2.
r"""Convert the Qwen models in the same format as LLaMA2.
Usage: python llamafy_qwen.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
"""
......
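Note on the conversion above: the script builds a `dict[str, torch.Tensor]`, renames Qwen weights into LLaMA2 layout, then saves the result in shards no larger than `shard_size`. A minimal sketch of such a sharded safetensors save, assuming the same dict-based typing; the helper name and file pattern here are illustrative, not the script's exact logic:

# Hypothetical sketch: shard a state dict and save it with safetensors.
import os
import torch
from safetensors.torch import save_file

def save_sharded(state_dict: dict[str, torch.Tensor], output_dir: str, max_shard_bytes: int) -> None:
    os.makedirs(output_dir, exist_ok=True)
    shard: dict[str, torch.Tensor] = {}
    shard_bytes, shard_idx = 0, 0
    for key, tensor in state_dict.items():
        tensor_bytes = tensor.numel() * tensor.element_size()
        if shard and shard_bytes + tensor_bytes > max_shard_bytes:
            save_file(shard, os.path.join(output_dir, f"model-{shard_idx:05d}.safetensors"))
            shard, shard_bytes, shard_idx = {}, 0, shard_idx + 1
        shard[key] = tensor.contiguous()
        shard_bytes += tensor_bytes
    if shard:
        save_file(shard, os.path.join(output_dir, f"model-{shard_idx:05d}.safetensors"))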
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
if __name__ == "__main__":
vision_config = Llama4VisionConfig(
hidden_size=1408,
image_size=336,
intermediate_size=5632,
num_attention_heads=16,
num_hidden_layers=4,
vision_output_dim=4096,
)
text_config = Llama4TextConfig(
hidden_size=512,
intermediate_size=1024,
intermediate_size_mlp=1024,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=2,
head_dim=512 // 8,
num_local_experts=2,
)
config = Llama4Config(vision_config=vision_config, text_config=text_config)
model = Llama4ForConditionalGeneration._from_config(config)
model.save_pretrained("tiny-llama4")
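Once the tiny checkpoint is written, it can be reloaded the same way for quick smoke tests; a minimal sketch (the parameter-count check is only illustrative):

# Hypothetical smoke test: reload the tiny checkpoint saved above.
from transformers import Llama4ForConditionalGeneration

model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4")
print(sum(p.numel() for p in model.parameters()), "parameters")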
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import time
import fire
from datasets import load_dataset
try:
import jieba
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge
jieba.setLogLevel(logging.CRITICAL)
jieba.initialize()
except ImportError:
print("Please install llamafactory with `pip install -e .[metrics]`.")
raise
def compute_metrics(sample):
hypothesis = list(jieba.cut(sample["predict"]))
reference = list(jieba.cut(sample["label"]))
bleu_score = sentence_bleu(
[list(sample["label"])],
list(sample["predict"]),
smoothing_function=SmoothingFunction().method3,
)
if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
else:
rouge = Rouge()
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
result = scores[0]
metric_result = {}
for k, v in result.items():
metric_result[k] = round(v["f"] * 100, 4)
metric_result["bleu-4"] = round(bleu_score * 100, 4)
return metric_result
def main(filename: str):
start_time = time.time()
dataset = load_dataset("json", data_files=filename, split="train")
dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
score_dict = dataset.to_dict()
average_score = {}
for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
print(f"{task}: {sum(scores) / len(scores):.4f}")
average_score[task] = sum(scores) / len(scores)
with open("predictions_score.json", "w", encoding="utf-8") as f:
json.dump(average_score, f, indent=4)
print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to predictions_score.json")
if __name__ == "__main__":
fire.Fire(main)
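The evaluation script expects a JSON-lines file in which every record carries a `predict` and a `label` string; a hedged sketch of preparing such an input (file name and script name are assumptions, not taken from this diff):

# Hypothetical input preparation for the BLEU/ROUGE script above.
import json

samples = [
    {"predict": "今天天气很好", "label": "今天天气不错"},
    {"predict": "the cat sat on the mat", "label": "the cat sat on a mat"},
]
with open("predictions.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
# Then (script name assumed): python eval_bleu_rouge.py predictions.jsonl
# which writes averaged rouge-1 / rouge-2 / rouge-l / bleu-4 scores to predictions_score.json.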
......@@ -18,7 +18,7 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
import fire
import torch
......@@ -44,11 +44,11 @@ def block_expansion(
shard_size: str = "5GB",
save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
num_layers = getattr(config, "num_hidden_layers")
if num_layers % num_expand != 0:
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
......@@ -70,7 +70,7 @@ def block_expansion(
split = num_layers // num_expand
layer_cnt = 0
state_dict = model.state_dict()
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
output_state_dict: dict[str, torch.Tensor] = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if f".{i:d}." in key:
......
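Block expansion interleaves freshly initialized copies among the original layers, so every original layer index has to shift by the number of copies inserted before it. A minimal sketch of that bookkeeping, assuming one copied layer is appended after every `split` original layers (consistent with `split = num_layers // num_expand` above; the helper is illustrative):

# Hypothetical sketch of the layer-index remapping in block expansion.
def remap_layer_index(orig_index: int, split: int) -> int:
    # each completed group of `split` original layers adds one extra slot before this layer
    return orig_index + orig_index // split

num_layers, num_expand = 32, 8
split = num_layers // num_expand  # 4
assert remap_layer_index(0, split) == 0
assert remap_layer_index(4, split) == 5  # one expanded layer already inserted before it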
......@@ -38,8 +38,8 @@ def quantize_loftq(
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
......@@ -72,7 +72,7 @@ def quantize_loftq(
print(f"Adapter weights saved in {loftq_dir}")
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")
......
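For context, LoftQ initialization in PEFT is configured through `LoftQConfig` plus `init_lora_weights="loftq"`; a minimal sketch under those assumptions (model path and target modules are placeholders):

# Hypothetical sketch of LoftQ-style initialization with PEFT; paths/targets are placeholders.
import torch
from peft import LoftQConfig, LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path_to_model", torch_dtype=torch.bfloat16)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="loftq",
    loftq_config=LoftQConfig(loftq_bits=4),  # quantize base weights to 4 bit during init
)
peft_model = get_peft_model(model, lora_config)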
......@@ -37,8 +37,8 @@ def quantize_pissa(
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
......@@ -67,7 +67,7 @@ def quantize_pissa(
print(f"Adapter weights saved in {pissa_dir}")
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model: PreTrainedModel = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print(f"Model weights saved in {output_dir}")
......
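PiSSA follows the same pattern but needs only the `init_lora_weights` flag; a minimal sketch under the same placeholder assumptions:

# Hypothetical sketch of PiSSA initialization with PEFT; paths/targets are placeholders.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path_to_model", torch_dtype=torch.bfloat16)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="pissa_niter_16",  # fast-SVD variant; plain "pissa" also works
)
peft_model = get_peft_model(model, lora_config)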
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import fire
from peft import PeftModel
from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration
def merge_lora(
base_model_path: str,
lora_checkpoint_path: str,
extra_file: str = "spk_dict.pt",
submodule_name: str = "thinker",
save_path: str = "./merged_model_checkpoint",
):
"""Load the original model, tokenizer, and processor configuration, merge the LoRA weights.
For a specified submodule, and save the final merged model along with its configurations.
Args:
base_model_path (str): Path to the original model directory.
lora_checkpoint_path (str): Path to the directory containing LoRA weights.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
submodule_name (str): Name of the submodule to merge (default: "thinker").
save_path (str): Directory where the merged model and configurations will be saved.
"""
# 1. Load the original model, tokenizer, and processor
model = AutoModel.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
try:
processor = AutoProcessor.from_pretrained(base_model_path)
except Exception:
print("Processor configuration not found, skipping processor load.")
processor = None
print("Successfully loaded the original model, tokenizer, and processor (if available).")
# 2. Extract the submodule to be merged (e.g., model.thinker)
if not hasattr(model, submodule_name):
raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")
base_submodule = getattr(model, submodule_name)
print(f"Successfully extracted submodule: {submodule_name}.")
# 3. Load the LoRA weights onto the extracted submodule
lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
print("LoRA weights loaded successfully.")
# 4. Merge the LoRA weights into the submodule and unload the LoRA modules
merged_submodule = lora_model.merge_and_unload()
print("LoRA weights merged successfully.")
# 5. Replace the original submodule with the merged submodule in the model
setattr(model, submodule_name, merged_submodule)
# 6. Save the final merged model along with the tokenizer and processor configuration
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
if processor is not None:
processor.save_pretrained(save_path)
print(f"Merged model and configuration saved to {save_path}.")
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
shutil.copy(source_file, target_file)
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
else:
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
def save_full_model(
saved_thinker_path: str,
base_model_path: str,
save_path: str,
extra_file: str = "spk_dict.pt",
):
"""Load the saved thinker module and the original model, replace the thinker in the original model.
Then save the complete model along with its tokenizer and processor configuration.
Args:
saved_thinker_path (str): Path to the saved thinker weights.
base_model_path (str): Directory path of the original model.
save_path (str): Directory where the final complete model will be saved.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
"""
# Load the thinker module
thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(saved_thinker_path, device_map="cpu")
# Load the original model
base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
# Replace the thinker module in the original model
base_model.thinker = thinker
# Load the processor and tokenizer
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
# Save the complete model along with its configurations
base_model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")
source_file = os.path.join(base_model_path, extra_file)
target_file = os.path.join(save_path, extra_file)
if os.path.exists(source_file):
shutil.copy(source_file, target_file)
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
else:
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
if __name__ == "__main__":
fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
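Neither docstring in this new script spells out an invocation; given the `fire.Fire({...})` mapping above, a hedged usage sketch (the script name and paths are assumptions) would be:

# Hypothetical CLI usage for the two commands registered above; paths are placeholders.
#   python qwen_omni_merge.py merge_lora \
#       --base_model_path Qwen/Qwen2.5-Omni-7B \
#       --lora_checkpoint_path saves/qwen2_5omni-7b/lora/sft \
#       --save_path ./merged_model_checkpoint
#   python qwen_omni_merge.py save_full \
#       --saved_thinker_path saves/thinker_full \
#       --base_model_path Qwen/Qwen2.5-Omni-7B \
#       --save_path ./full_model_checkpoint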
......@@ -29,8 +29,8 @@ def calculate_flops(
seq_length: int = 512,
flash_attn: str = "auto",
):
r"""
Calculates the flops of pre-trained models.
r"""Calculate the flops of pre-trained models.
Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
"""
with get_accelerator().device(0):
......
......@@ -45,8 +45,8 @@ def calculate_lr(
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
Usage:
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
"""
......@@ -89,9 +89,8 @@ def calculate_lr(
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
lr, valid_ratio * 100, token_batch_size
)
f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
f"and effective token batch size {token_batch_size:.2f}"
)
......
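The scaling rule above is a square-root law in the effective token batch size. A worked sketch, assuming BASE_LR = 3e-4 and BASE_BS = 4,000,000 tokens (these constants sit outside the shown hunk, so treat them as assumptions) and ignoring the valid-token ratio:

# Hypothetical worked example of the sqrt scaling rule; BASE_LR / BASE_BS are assumed values.
import math

BASE_LR, BASE_BS = 3e-4, 4_000_000
token_batch_size = 16 * 1024  # batch_size * cutoff_len from the usage example
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)
print(f"{lr:.2e}")  # ~1.92e-05; divided by 6 again for Mistral/Gemma models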
......@@ -34,9 +34,7 @@ def compute_model_flops(
include_recompute: bool = False,
include_flashattn: bool = False,
) -> int:
r"""
Calculates the FLOPs of model per forward/backward pass.
"""
r"""Calculate the FLOPs of model per forward/backward pass."""
config = AutoConfig.from_pretrained(model_name_or_path)
hidden_size = getattr(config, "hidden_size", None)
vocab_size = getattr(config, "vocab_size", None)
......@@ -86,9 +84,7 @@ def compute_model_flops(
def compute_device_flops(world_size: int) -> float:
r"""
Calculates the FLOPs of the device capability per second.
"""
r"""Calculate the FLOPs of the device capability per second."""
device_name = torch.cuda.get_device_name()
if "H100" in device_name or "H800" in device_name:
return 989 * 1e12 * world_size
......@@ -114,8 +110,8 @@ def calculate_mfu(
liger_kernel: bool = False,
unsloth_gc: bool = False,
) -> float:
r"""
Calculates MFU for given model and hyper-params.
r"""Calculate MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
"""
args = {
......
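MFU is the achieved model FLOP throughput divided by the device peak, which is what the two helpers above feed into; a hedged sketch of that arithmetic (the step-time variable is illustrative):

# Hypothetical MFU arithmetic combining the two helpers above.
def compute_mfu(model_flops_per_step: float, seconds_per_step: float, device_flops_per_second: float) -> float:
    # fraction of theoretical peak FLOPs actually used during training
    return (model_flops_per_step / seconds_per_step) / device_flops_per_second

# e.g. an H800 (989 TFLOPS per the table above) spending 2 s on a 1e15-FLOP step:
print(compute_mfu(1e15, 2.0, 989e12))  # ~0.51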
......@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Sequence
from typing import Any, Literal, Optional
import fire
import torch
......@@ -30,16 +30,12 @@ from llamafactory.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
r"""Data collator for pairwise data."""
train_on_prompt: bool = False
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
"""
def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
r"""Pad batched data to the longest sequence in the batch."""
chosen_features = []
for feature in features:
chosen_features.append(
......@@ -68,8 +64,8 @@ def calculate_ppl(
max_samples: Optional[int] = None,
train_on_prompt: bool = False,
):
r"""
Calculates the ppl on the dataset of the pre-trained models.
r"""Calculate the ppl on the dataset of the pre-trained models.
Usage: export CUDA_VISIBLE_DEVICES=0
python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
"""
......@@ -111,17 +107,17 @@ def calculate_ppl(
criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = []
batch: Dict[str, "torch.Tensor"]
batch: dict[str, torch.Tensor]
with torch.no_grad():
for batch in tqdm(dataloader, desc="Computing perplexities"):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
shift_labels: torch.Tensor = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()
......
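The per-sentence perplexity computed above is the exponential of the mean token cross-entropy over non-ignored positions; a small numeric check of that masked mean (values are made up):

# Hypothetical numeric check of the masked-mean perplexity computed above.
import torch

token_nll = torch.tensor([[2.0, 1.0, 3.0, 0.0]])                 # per-token cross-entropy
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])                 # last position is IGNORE_INDEX
sentence_logps = (token_nll * loss_mask).sum(-1) / loss_mask.sum(-1)  # mean NLL = 2.0
print(sentence_logps.exp())  # tensor([7.3891]) == exp(2.0), the sentence perplexity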
......@@ -29,8 +29,8 @@ def length_cdf(
template: str = "default",
interval: int = 1000,
):
r"""
Calculates the distribution of the input lengths in the dataset.
r"""Calculate the distribution of the input lengths in the dataset.
Usage: export CUDA_VISIBLE_DEVICES=0
python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
"""
......
......@@ -52,11 +52,11 @@ def vllm_infer(
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
):
r"""
Performs batch generation using vLLM engine, which supports tensor parallelism.
r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
"""
check_version("vllm>=0.4.3,<=0.7.3")
check_version("vllm>=0.4.3,<=0.8.2")
if pipeline_parallel_size > get_device_count():
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
......@@ -92,8 +92,20 @@ def vllm_infer(
multi_modal_data = {
"image": template_obj.mm_plugin._regularize_images(
sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
)
)["images"]
}
elif sample["videos"]:
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
)["videos"]
}
elif sample["audios"]:
audio_data = template_obj.mm_plugin._regularize_audios(
sample["audios"],
sampling_rate=16000,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
......@@ -131,7 +143,7 @@ def vllm_infer(
"enable_lora": model_args.adapter_name_or_path is not None,
}
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
......
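With this change each prompt carries at most one modality in `multi_modal_data`; a hedged sketch of the structures built above (values are placeholders, and the audio entries are (waveform, sampling_rate) pairs produced by the zip):

# Hypothetical shape of multi_modal_data per prompt after this change; values are placeholders.
import numpy as np
from PIL import Image

image_case = {"image": [Image.new("RGB", (336, 336))]}                 # regularized PIL images
video_case = {"video": [[Image.new("RGB", (336, 336))] * 2]}           # one list of frames per video
audio_case = {"audio": [(np.zeros(16000, dtype=np.float32), 16000)]}   # (waveform, sampling_rate)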
......@@ -14,7 +14,6 @@
import os
import re
from typing import List
from setuptools import find_packages, setup
......@@ -27,14 +26,14 @@ def get_version() -> str:
return version
def get_requires() -> List[str]:
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> List[str]:
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
......@@ -47,14 +46,15 @@ extra_require = {
"torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
"liger-kernel": ["liger-kernel"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.7.3"],
"vllm": ["vllm>=0.4.3,<=0.8.2"],
"sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
......@@ -69,6 +69,7 @@ extra_require = {
"msgpack",
"referencing",
"jsonschema_specifications",
"transformers==4.48.3",
],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
......@@ -82,11 +83,11 @@ def main():
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga AT buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework",
author_email="hiyouga@buaa.edu.cn",
description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
......
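The extras above map directly to pip's optional-dependency syntax, for example:

#   pip install -e ".[metrics]"
#   pip install -e ".[vllm,liger-kernel]"
#   pip install -e ".[sglang]"   # pins transformers==4.48.3 per the extra above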
......@@ -12,18 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Efficient fine-tuning of large language models.
r"""Efficient fine-tuning of large language models.
Level:
api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
main:
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.2.0
accelerate>=0.34.0,<=1.2.1
peft>=0.11.1,<=0.12.0
transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.4.1
accelerate>=0.34.0,<=1.5.2
peft>=0.14.0,<=0.15.0
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
......
......@@ -16,9 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional
from typing_extensions import Annotated
from typing import Annotated, Optional
from ..chat import ChatModel
from ..extras.constants import EngineName
......
......@@ -18,11 +18,12 @@ import json
import os
import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Optional
from ..data import Role as DataRole
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
......@@ -55,7 +56,7 @@ if is_requests_available():
if TYPE_CHECKING:
from ..chat import ChatModel
from ..data.mm_plugin import ImageInput
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
......@@ -71,7 +72,14 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
) -> tuple[
list[dict[str, str]],
Optional[str],
Optional[str],
Optional[list["ImageInput"]],
Optional[list["VideoInput"]],
Optional[list["AudioInput"]],
]:
if is_env_enabled("API_VERBOSE", "1"):
logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
......@@ -87,7 +95,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
images = []
images, videos, audios = [], [], []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
......@@ -106,7 +114,7 @@ def _process_request(
for input_item in message.content:
if input_item.type == "text":
text_content += input_item.text
else:
elif input_item.type == "image_url":
text_content += IMAGE_PLACEHOLDER
image_url = input_item.image_url.url
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
......@@ -117,6 +125,28 @@ def _process_request(
image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB"))
elif input_item.type == "video_url":
text_content += VIDEO_PLACEHOLDER
video_url = input_item.video_url.url
if os.path.isfile(video_url): # local file
video_stream = open(video_url, "rb")
else: # web uri
video_stream = requests.get(video_url, stream=True).raw
videos.append(video_stream)
elif input_item.type == "audio_url":
text_content += AUDIO_PLACEHOLDER
audio_url = input_item.audio_url.url
if os.path.isfile(audio_url): # local file
audio_stream = open(audio_url, "rb")
else: # web uri
audio_stream = requests.get(audio_url, stream=True).raw
audios.append(audio_stream)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
)
input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
else:
......@@ -131,7 +161,7 @@ def _process_request(
else:
tools = None
return input_messages, system, tools, images or None
return input_messages, system, tools, images or None, videos or None, audios or None
def _create_stream_chat_completion_chunk(
......@@ -150,12 +180,14 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request)
input_messages, system, tools, images, videos, audios = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
images,
videos,
audios,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
......@@ -201,7 +233,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images = _process_request(request)
input_messages, system, tools, images, videos, audios = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
......@@ -216,6 +248,8 @@ async def create_stream_chat_completion_response(
system,
tools,
images,
videos,
audios,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
......
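With the new placeholders and URL fields in place, a multimodal chat request handled by `_process_request` could look like the following (URLs and the model name are illustrative):

# Hypothetical request body accepted after this change; URLs are placeholders.
request_body = {
    "model": "test-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this clip."},
                {"type": "video_url", "video_url": {"url": "/data/sample.mp4"}},
                {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
            ],
        }
    ],
    "stream": False,
}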
......@@ -13,14 +13,14 @@
# limitations under the License.
import json
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
def dictify(data: "BaseModel") -> dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
......
......@@ -14,7 +14,7 @@
import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Literal
......@@ -45,7 +45,7 @@ class ModelCard(BaseModel):
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: List[ModelCard] = []
data: list[ModelCard] = []
class Function(BaseModel):
......@@ -56,7 +56,7 @@ class Function(BaseModel):
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
parameters: dict[str, Any]
class FunctionAvailable(BaseModel):
......@@ -70,38 +70,41 @@ class FunctionCall(BaseModel):
function: Function
class ImageURL(BaseModel):
class URL(BaseModel):
url: str
detail: Literal["auto", "low", "high"] = "auto"
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url"]
type: Literal["text", "image_url", "video_url", "audio_url"]
text: Optional[str] = None
image_url: Optional[ImageURL] = None
image_url: Optional[URL] = None
video_url: Optional[URL] = None
audio_url: Optional[URL] = None
class ChatMessage(BaseModel):
role: Role
content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None
content: Optional[Union[str, list[MultimodalInputItem]]] = None
tool_calls: Optional[list[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
tool_calls: Optional[list[FunctionCall]] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: Optional[List[FunctionAvailable]] = None
messages: list[ChatMessage]
tools: Optional[list[FunctionAvailable]] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop: Optional[Union[str, list[str]]] = None
stream: bool = False
......@@ -128,7 +131,7 @@ class ChatCompletionResponse(BaseModel):
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
choices: list[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage
......@@ -137,12 +140,12 @@ class ChatCompletionStreamResponse(BaseModel):
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionStreamResponseChoice]
choices: list[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
model: str
messages: List[str]
messages: list[str]
max_length: Optional[int] = None
......@@ -150,4 +153,4 @@ class ScoreEvaluationResponse(BaseModel):
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]
scores: list[float]
......@@ -13,8 +13,9 @@
# limitations under the License.
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
if TYPE_CHECKING:
......@@ -36,8 +37,7 @@ class Response:
class BaseEngine(ABC):
r"""
Base class for inference engine of chat models.
r"""Base class for inference engine of chat models.
Must implement async methods: chat(), stream_chat() and get_scores().
"""
......@@ -47,7 +47,7 @@ class BaseEngine(ABC):
tokenizer: "PreTrainedTokenizer"
can_generate: bool
template: "Template"
generating_args: Dict[str, Any]
generating_args: dict[str, Any]
@abstractmethod
def __init__(
......@@ -57,50 +57,42 @@ class BaseEngine(ABC):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
r"""
Initializes an inference engine.
"""
r"""Initialize an inference engine."""
...
@abstractmethod
async def chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
) -> list["Response"]:
r"""Get a list of responses of the chat model."""
...
@abstractmethod
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
audios: Optional[Sequence["AudioInput"]] = None,
images: Optional[list["ImageInput"]] = None,
videos: Optional[list["VideoInput"]] = None,
audios: Optional[list["AudioInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Gets the response token-by-token of the chat model.
"""
r"""Get the response token-by-token of the chat model."""
...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
batch_input: list[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
) -> list[float]:
r"""Get a list of scores of the reward model."""
...
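A concrete engine only has to satisfy these three async methods plus the constructor; a minimal, non-functional sketch of the shape a subclass takes (bodies are stubs, not a real engine):

# Hypothetical skeleton of a BaseEngine subclass; a real engine would build Response objects.
class EchoEngine(BaseEngine):
    def __init__(self, model_args, data_args, finetuning_args, generating_args) -> None:
        self.can_generate = True
        self.generating_args = {}

    async def chat(self, messages, system=None, tools=None, images=None, videos=None, audios=None, **kwargs):
        # a real engine would run generation and return a list of Response objects here
        return []

    async def stream_chat(self, messages, system=None, tools=None, images=None, videos=None, audios=None, **kwargs):
        for token in messages[-1]["content"].split():
            yield token

    async def get_scores(self, batch_input, **kwargs):
        return [0.0 for _ in batch_input]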