Commit ca625f43 authored by shihm

uodata

parent 7164651d
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# Freeze Configuration
peft_config:
  name: freeze
  freeze_trainable_layers: 2 # Train the last 2 layers
  freeze_trainable_modules: all # In these layers, train specific modules
  freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto
# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_freeze
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 2.0e-5
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
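For intuition, the freeze settings above can be pictured as plain PyTorch parameter freezing. The following is only an illustrative sketch (not this repository's implementation), and the attribute path `model.model.layers` is an assumption that holds for Qwen3-style decoder models.

# Illustrative sketch only: what "freeze_trainable_layers: 2" with
# "freeze_trainable_modules: all" roughly amounts to for a decoder-only model.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", trust_remote_code=True)

for param in model.parameters():
    param.requires_grad = False  # freeze the whole network first

for layer in model.model.layers[-2:]:  # then unfreeze every module in the last 2 decoder layers
    for param in layer.parameters():
        param.requires_grad = True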
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto
dist_config:
  name: deepspeed
  config_file: examples/deepspeed/ds_z3_config.json
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/Qwen3-0.6B-deepspeed
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
kernel_config:
  name: auto
  include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
  name: fsdp2
  dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
model: Qwen/Qwen3-4B
peft_config:
  name: lora
  adapter_name_or_path: ./outputs/test_lora
export_dir: ./merge_lora_model
export_size: 5
infer_dtype: auto
model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all
# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto
# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
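As a point of reference, the LoRA hyperparameters above map closely onto a PEFT `LoraConfig`. The snippet below is a hedged sketch; reading `target_modules: all` as PEFT's `"all-linear"` is an assumption, not something this commit states.

# Hedged sketch: an approximate PEFT equivalent of the peft_config block above.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",  # assumption: "all" is interpreted as every linear projection
    task_type="CAUSAL_LM",
)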
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all
# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto
# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
# Quantization Config
quant_config:
  name: bnb # choice: auto/bnb. If auto is selected, the quantization method is chosen automatically based on the model and environment.
  quantization_bit: 4 # choice: 8/4 (bnb)
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_quantization
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128
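For orientation, `quant_config` with `name: bnb` and `quantization_bit: 4` corresponds roughly to loading the base model through bitsandbytes. This sketch is illustrative only; the NF4 quant type and bfloat16 compute dtype are assumptions rather than values taken from this commit.

# Rough bitsandbytes equivalent of a 4-bit quant_config (illustrative, not this repo's loader).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantization_bit: 4
    bnb_4bit_quant_type="nf4",              # assumed quantization type
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed compute dtype
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", quantization_config=bnb_config, trust_remote_code=True
)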
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "llamafactory"
dynamic = ["version"]
description = "Unified Efficient Fine-Tuning of 100+ LLMs"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11.0"
authors = [
  { name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
]
keywords = [
"AI",
"LLM",
"GPT",
"ChatGPT",
"Llama",
"Transformer",
"DeepSeek",
"Pytorch"
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
# core deps
"torch>=2.4.0",
"torchvision>=0.19.0",
"torchaudio>=2.4.0",
"transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0",
"datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1",
"trl>=0.18.0,<=0.24.0",
"torchdata>=0.10.0,<=0.11.0",
# gui
"gradio>=4.38.0,<=5.50.0",
"matplotlib>=3.7.0",
"tyro<0.9.0",
# ops
"einops",
"numpy",
"pandas",
"scipy",
# model and tokenizer
"sentencepiece",
"tiktoken",
"modelscope",
"hf-transfer",
"safetensors",
# python
"av",
"fire",
"omegaconf",
"packaging",
"protobuf",
"pyyaml",
"pydantic",
# api
"uvicorn",
"fastapi",
"sse-starlette"
]
[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "build"]
metrics = ["nltk", "jieba", "rouge-chinese"]
deepspeed = ["deepspeed>=0.10.0,<=0.16.9"]
[project.scripts]
llamafactory-cli = "llamafactory.cli:main"
lmf = "llamafactory.cli:main"
[project.urls]
Homepage = "https://github.com/hiyouga/LLaMA-Factory"
Repository = "https://github.com/hiyouga/LLaMA-Factory"
[tool.hatch.build.targets.wheel]
packages = ["src/llamafactory"]
[tool.hatch.version]
path = "src/llamafactory/extras/env.py"
pattern = "VERSION = \"(?P<version>[^\"]+)\""
[tool.ruff]
target-version = "py311"
line-length = 119
indent-width = 4
...@@ -30,6 +108,8 @@ ignore = [
"E501", # line too long
"E731", # lambda function
"E741", # ambiguous var name
"UP007", # no upgrade union
"UP045", # no upgrade optional
"D100", # no doc public module
"D101", # no doc public class
"D102", # no doc public method
...@@ -73,23 +153,3 @@ indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.uv]
conflicts = [
[
{ extra = "torch-npu" },
{ extra = "aqlm" },
],
[
{ extra = "torch-npu" },
{ extra = "liger-kernel" },
],
[
{ extra = "torch-npu" },
{ extra = "vllm" },
],
[
{ extra = "sglang" },
{ extra = "minicpm_v" },
],
]
# core deps
python>=3.9,<=3.11
torch>=2.0.0,<=2.6.0
torchvision>=0.15.0,<=0.21.0
transformers>=4.49.0,<=4.52.4,!=4.52.0; sys_platform != 'darwin'
transformers>=4.49.0,<=4.51.3,!=4.52.0; sys_platform == 'darwin'
datasets>=2.16.0,<=3.6.0
accelerate>=1.3.0,<=1.7.0
peft>=0.14.0,<=0.15.2
trl>=0.8.6,<=0.9.6
tokenizers>=0.19.0,<=0.21.1
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import Any
import fire
import torch
from peft import PeftModel
from torch.utils.data import Dataset
from transformers import DataCollatorForSeq2Seq, Qwen2_5_VLProcessor
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer
from llamafactory.train.callbacks import LogCallback
from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer
class DummyDataset(Dataset):
def __init__(self, size: int = 1000, seq_length: int = 1024, processor: Qwen2_5_VLProcessor = None):
self.size = size
self.seq_length = seq_length
self.vocab_size = 32768
self.processor = processor
image_token_num = 18 * 18 // (2 * 2)
image_t = 2
self.text_seqlen = seq_length // 4 # 25% text
video_seq_length = self.seq_length - self.text_seqlen - image_t * image_token_num
video_t = video_seq_length // image_token_num
self.image_size = [18 * 18 * image_t, 1176]
self.image_grid_thw = torch.tensor([[1, 18, 18]] * image_t, dtype=torch.long)
self.image_seqlen = image_t * image_token_num
self.video_size = [18 * 18 * video_t, 1176]
self.video_grid_thw = torch.tensor([[video_t, 18, 18]], dtype=torch.long)
self.video_seqlen = video_t * image_token_num
def __len__(self):
return self.size
def __getitem__(self, index: int):
input_ids = torch.randint(low=0, high=self.vocab_size, size=(self.seq_length,))
input_ids[: self.image_seqlen] = self.processor.image_token_id
input_ids[self.image_seqlen : self.image_seqlen + self.video_seqlen] = self.processor.video_token_id
attention_mask = torch.ones((self.seq_length,), dtype=torch.long)
labels = input_ids.clone()
labels[: self.image_seqlen + self.video_seqlen] = IGNORE_INDEX
pixel_values = torch.rand(self.image_size, dtype=torch.float32)
pixel_values_videos = torch.rand(self.video_size, dtype=torch.float32)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"image_grid_thw": self.image_grid_thw,
"video_grid_thw": self.video_grid_thw,
}
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __post_init__(self):
if isinstance(self.model, PeftModel):
self.model = self.model.base_model.model
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni
elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0
else:
self.get_rope_func = None
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_pixel_values = [feature.pop("pixel_values") for feature in features]
batch_pixel_values_videos = [feature.pop("pixel_values_videos") for feature in features]
batch_image_grid_thw = [feature.pop("image_grid_thw") for feature in features]
batch_video_grid_thw = [feature.pop("video_grid_thw") for feature in features]
batch: dict[str, torch.Tensor] = super().__call__(features)
batch["pixel_values"] = torch.cat(batch_pixel_values, dim=0)
batch["pixel_values_videos"] = torch.cat(batch_pixel_values_videos, dim=0)
batch["image_grid_thw"] = torch.cat(batch_image_grid_thw, dim=0)
batch["video_grid_thw"] = torch.cat(batch_video_grid_thw, dim=0)
if self.get_rope_func is not None:
rope_index_kwargs = {
"input_ids": batch["input_ids"],
"image_grid_thw": batch["image_grid_thw"],
"video_grid_thw": batch["video_grid_thw"],
"attention_mask": (batch["attention_mask"] >= 1).float(),
}
batch["position_ids"], batch["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
if "position_ids" not in batch or batch["position_ids"].dim() != 3:
raise ValueError("Qwen2VL requires 3D position ids for mrope.")
return batch
def bench_qwen(
model_name_or_path: str = "Qwen/Qwen2-VL-7B-Instruct",
batch_size: int = 1,
seq_length: int = 2048,
liger_kernel: bool = False,
deepspeed_stage: int = 3,
):
os.environ["LLAMABOARD_ENABLED"] = "true"
os.environ["LLAMABOARD_WORKDIR"] = "output/dummy_dir"
args = {
"model_name_or_path": model_name_or_path,
"enable_liger_kernel": liger_kernel,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
"dataset": "alpaca_en_demo",
"template": "qwen2_vl",
"cutoff_len": seq_length,
"output_dir": "output/dummy_dir",
"logging_steps": 10,
"save_strategy": "no",
"save_only_model": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": 1000,
"bf16": True,
"include_num_input_tokens_seen": True,
"report_to": "none",
}
if deepspeed_stage in [2, 3]:
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
model_args, _, training_args, finetuning_args, _ = get_train_args(args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = DummyDataset(size=100000, seq_length=seq_length, processor=tokenizer_module["processor"])
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = MultiModalDataCollatorForSeq2Seq(
tokenizer=tokenizer, model=model, pad_to_multiple_of=8, label_pad_token_id=IGNORE_INDEX
)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=[LogCallback()],
train_dataset=trainset,
**tokenizer_module,
)
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if __name__ == "__main__":
fire.Fire(bench_qwen)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer, Qwen3Config, Qwen3ForCausalLM
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
config = Qwen3Config(
hidden_size=1408,
image_size=336,
intermediate_size=5632,
num_attention_heads=16,
num_hidden_layers=4,
vision_output_dim=4096,
)
model = Qwen3ForCausalLM(config)  # build from the config with randomly initialized weights
model.save_pretrained("tiny-qwen3")
tokenizer.save_pretrained("tiny-qwen3")
model.push_to_hub("llamafactory/tiny-random-qwen3")
tokenizer.push_to_hub("llamafactory/tiny-random-qwen3")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a HuggingFace model to DCP checkpoint format.
Usage:
python scripts/hf2dcp.py convert --hf_path=/path/to/hf --dcp_path=/path/to/dcp
Arguments:
hf_path: Path to the HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
import fire
import torch
import torch.distributed.checkpoint as dcp
from transformers import AutoModelForCausalLM
def convert(hf_path: str, dcp_path: str) -> None:
"""Convert HF model weights to DCP.
Args:
hf_path: HuggingFace model directory.
dcp_path: Output path (directory) for DCP checkpoint.
"""
if not hf_path or not dcp_path:
raise ValueError("Both 'hf_path' and 'dcp_path' are required.")
print(f"Loading HF model from {hf_path}...")
model = AutoModelForCausalLM.from_pretrained(hf_path, device_map="cpu", torch_dtype=torch.bfloat16)
print(f"Saving to DCP format at {dcp_path}...")
dcp.save(model.state_dict(), checkpoint_id=dcp_path)
print("Done!")
def help() -> None:
"""Show help message."""
print(__doc__)
if __name__ == "__main__":
fire.Fire({"convert": convert, "help": help, "--convert": convert})
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# This code is modified from the ROLL library.
# https://github.com/alibaba/ROLL/blob/main/mcore_adapter/tools/convert.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import fire
import torch
from mcore_adapter.models.converter.post_converter import convert_checkpoint_to_hf, convert_checkpoint_to_mca
from mcore_adapter.training_args import DistributingParallelArguments
from mcore_adapter.utils import get_logger
from transformers import AutoConfig
logger = get_logger(__name__)
def convert_mca_to_hf(
checkpoint_path: str,
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: int | None = None,
):
"""Convert megatron checkpoint to HuggingFace format.
Args:
checkpoint_path: Path to the checkpoint to convert
output_path: Path to save the converted checkpoint
bf16: Use bfloat16 precision
fp16: Use float16 precision
convert_model_max_length: Change the model_max_length in hf config.json
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
torch_dtype = None
if bf16:
torch_dtype = torch.bfloat16
elif fp16:
torch_dtype = torch.float16
convert_checkpoint_to_hf(checkpoint_path, output_path, torch_dtype=torch_dtype)
if convert_model_max_length is not None:
config = AutoConfig.from_pretrained(output_path, trust_remote_code=True)
config.model_max_length = convert_model_max_length
config.save_pretrained(output_path)
def convert(
checkpoint_path: str,
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: int | None = None,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: int | None = None,
):
"""Convert checkpoint between MCA and HuggingFace formats.
Args:
checkpoint_path: Path to the checkpoint to convert
output_path: Path to save the converted checkpoint
bf16: Use bfloat16 precision
fp16: Use float16 precision
convert_model_max_length: Change the model_max_length in hf config.json
tensor_model_parallel_size: Tensor model parallel size
pipeline_model_parallel_size: Pipeline model parallel size
expert_model_parallel_size: Expert model parallel size
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
mca_config_path = os.path.join(checkpoint_path, "mca_config.json")
from_mca = os.path.exists(mca_config_path)
if not from_mca:
dist_args = DistributingParallelArguments(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
expert_model_parallel_size=expert_model_parallel_size,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
)
convert_checkpoint_to_mca(
checkpoint_path,
output_path,
dist_args,
bf16=bf16,
fp16=fp16,
)
else:
convert_mca_to_hf(
checkpoint_path=checkpoint_path,
output_path=output_path,
bf16=bf16,
fp16=fp16,
convert_model_max_length=convert_model_max_length,
)
def main():
fire.Fire(convert)
if __name__ == "__main__":
main()
...@@ -29,33 +29,30 @@ import shutil
import fire
from peft import PeftModel
from transformers import AutoConfig, AutoModelForTextToWaveform, AutoProcessor
from transformers.utils import cached_file
def merge_lora(
model_path: str,
lora_path: str,
extra_file: str = "spk_dict.pt",
submodule_name: str = "thinker",
save_path: str = "./merged_model_checkpoint",
):
"""Load the original model, merge the LoRA weights.
For a specified submodule, and save the final merged model along with its configurations.
Args:
model_path (str): Path to the original model directory.
lora_path (str): Path to the directory containing LoRA weights.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
submodule_name (str): Name of the submodule to merge (default: "thinker").
save_path (str): Directory where the merged model and configurations will be saved.
"""
# 1. Load the original model
model = AutoModelForTextToWaveform.from_pretrained(model_path, torch_dtype="auto", device_map="cpu")
print("Successfully loaded the original model.")
# 2. Extract the submodule to be merged (e.g., model.thinker)
...@@ -66,13 +63,13 @@ def merge_lora(
print(f"Successfully extracted submodule: {submodule_name}.")
# 3. Load the LoRA weights onto the extracted submodule
lora_model = PeftModel.from_pretrained(base_submodule, lora_path)
processor = AutoProcessor.from_pretrained(lora_path)
print("Successfully loaded LoRA weights and processor.")
# 4. Merge the LoRA weights into the submodule and unload the LoRA modules
merged_submodule = lora_model.merge_and_unload()
print("Successfully merged LoRA weights.")
# 5. Replace the original submodule with the merged submodule in the model
setattr(model, submodule_name, merged_submodule)
...@@ -80,20 +77,19 @@ def merge_lora(
# 6. Save the final merged model along with the tokenizer and processor configuration
model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Merged model and processor saved to {save_path}.")
try:
source_file = cached_file(path_or_repo_id=model_path, filename=extra_file)
shutil.copy(source_file, os.path.join(save_path, extra_file))
print(f"File '{extra_file}' copied from {model_path} to {save_path}.")
except Exception:
print(f"File '{extra_file}' not found in {model_path}, skipping copy.")
def save_full_model(
model_path: str,
thinker_path: str,
save_path: str = "./merged_model_checkpoint",
extra_file: str = "spk_dict.pt",
):
...@@ -102,34 +98,42 @@ def save_full_model(
Then save the complete model along with its tokenizer and processor configuration.
Args:
model_path (str): Directory path of the original model.
thinker_path (str): Path to the saved thinker weights.
save_path (str): Directory where the merged model and configurations will be saved.
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
"""
# 1. Load the saved thinker module and the original model
config = AutoConfig.from_pretrained(model_path)
if getattr(config, "model_type") == "qwen2_5_omni":
from transformers.models.qwen2_5_omni import Qwen2_5OmniThinkerForConditionalGeneration # type: ignore
ThinkerClass = Qwen2_5OmniThinkerForConditionalGeneration
elif getattr(config, "model_type") == "qwen3_omni_moe":
from transformers.models.qwen3_omni_moe import Qwen3OmniMoeThinkerForConditionalGeneration # type: ignore
ThinkerClass = Qwen3OmniMoeThinkerForConditionalGeneration
else:
raise ValueError(f"Unsupported model type: {getattr(config, 'model_type')}.")
thinker = ThinkerClass.from_pretrained(thinker_path, torch_dtype="auto", device_map="cpu")
base_model = AutoModelForTextToWaveform.from_pretrained(model_path, torch_dtype="auto", device_map="cpu")
base_model.thinker = thinker
processor = AutoProcessor.from_pretrained(thinker_path)
print("Successfully loaded model weights and processor.")
# 2. Save the complete model along with its tokenizer and processor configuration
base_model.save_pretrained(save_path)
processor.save_pretrained(save_path)
print(f"Merged model and processor saved to {save_path}.")
# 3. Copy the extra file from the base model directory to the save_path
try:
source_file = cached_file(path_or_repo_id=model_path, filename=extra_file)
shutil.copy(source_file, os.path.join(save_path, extra_file))
print(f"File '{extra_file}' copied from {model_path} to {save_path}.")
except Exception:
print(f"File '{extra_file}' not found in {model_path}, skipping copy.")
if __name__ == "__main__":
......
...@@ -14,7 +14,7 @@
import json
from dataclasses import dataclass
from typing import Any, Literal
import fire
import torch
...@@ -61,7 +61,7 @@ def calculate_ppl(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: int | None = None,
train_on_prompt: bool = False,
):
r"""Calculate the ppl on the dataset of the pre-trained models.
......
...@@ -14,9 +14,12 @@
import gc
import json
import time
import av
import fire
from datasets import load_dataset
from eval_bleu_rouge import compute_metrics
from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments
...@@ -33,6 +36,14 @@ if is_vllm_available():
from vllm.lora.request import LoRARequest
def _need_video_kwargs(template):
NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
if any(t in template for t in NEEDED_TEMPLATE):
return True
return False
def vllm_infer(
model_name_or_path: str,
adapter_name_or_path: str = None,
...@@ -40,18 +51,19 @@ def vllm_infer(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: int | None = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
matrix_save_name: str = None,
temperature: float = 0.95,
top_p: float = 0.7,
top_k: int = 50,
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
default_system: str | None = None,
enable_thinking: bool = True,
seed: int | None = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
image_min_pixels: int = 32 * 32,
...@@ -109,6 +121,7 @@ def vllm_infer(
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
model_preparation_start_time = time.time()
llm = LLM(**engine_args)
# load datasets
...@@ -132,7 +145,9 @@ def vllm_infer(
# Store all results in these lists
all_prompts, all_preds, all_labels = [], [], []
need_video_kwargs = _need_video_kwargs(template)
model_predict_start_time = time.time()
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], []
...@@ -147,6 +162,7 @@ def vllm_infer(
)["images"]
}
elif batch["videos"][j] is not None:
video_metadata, video_metadata_kwargs = None, None
video = batch["videos"][j]
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
...@@ -157,6 +173,25 @@ def vllm_infer(
video_maxlen=video_maxlen,
)["videos"]
}
if need_video_kwargs:
container = av.open(video[0], "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
video_stream, video_fps, video_maxlen
)
total_frames = video_stream.frames
video_metadata_kwargs = {
"fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
"do_sample_frames": False,
"total_num_frames": total_frames,
}
video_metadata = dict(
fps=video_fps,
frames_indices=sampling_indices,
total_num_frames=total_frames,
video_backend="opencv",
)
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
elif batch["audios"][j] is not None: elif batch["audios"][j] is not None:
audio = batch["audios"][j] audio = batch["audios"][j]
audio_data = template_obj.mm_plugin._regularize_audios( audio_data = template_obj.mm_plugin._regularize_audios(
...@@ -167,7 +202,11 @@ def vllm_infer( ...@@ -167,7 +202,11 @@ def vllm_infer(
else: else:
multi_modal_data = None multi_modal_data = None
vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}) vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
vllm_inputs.append(vllm_input_data)
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(
...@@ -185,6 +224,7 @@ def vllm_infer(
all_labels.extend(labels)
gc.collect()
model_predict_end_time = time.time()
# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(all_prompts, all_preds, all_labels):
...@@ -194,6 +234,49 @@ def vllm_infer(
print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
print("*" * 70)
# Write all matrix results when matrix_save_name is not None,
# The result matrix is referencing src.llamafactory.train.sft.workflow.run_sft # 127~132
# trainer.save_metrics("predict", predict_results.metrics)
#
# {
# "predict_bleu-4": 4.349975,
# "predict_model_preparation_time": 0.0128,
# "predict_rouge-1": 21.873359375,
# "predict_rouge-2": 4.144340625,
# "predict_rouge-l": 10.83949375,
# "predict_runtime": 131.664,
# "predict_samples_per_second": 0.076,
# "predict_steps_per_second": 0.008
# }
#
if matrix_save_name is not None:
predict_time = model_predict_end_time - model_predict_start_time
preparation_time = model_predict_start_time - model_preparation_start_time
start_time = time.time()
dataset = load_dataset("json", data_files=save_name, split="train")
dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
score_dict = dataset.to_dict()
average_score = {}
for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
score = sum(scores) / len(scores) if scores else 0.0
print(f"predict_{task}: {score:.4f}")
average_score["predict_" + task] = score
average_score["predict_model_preparation_time"] = preparation_time
average_score["predict_runtime"] = predict_time
num_steps = len(range(0, len(train_dataset), batch_size))
average_score["predict_samples_per_second"] = len(dataset) / predict_time if predict_time > 0 else 0.0
average_score["predict_steps_per_second"] = num_steps / predict_time if predict_time > 0 else 0.0
with open(matrix_save_name, "w", encoding="utf-8") as f:
json.dump(average_score, f, indent=4)
print("*" * 70)
print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to {matrix_save_name}.")
print("*" * 70)
if __name__ == "__main__":
fire.Fire(vllm_infer)
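The metrics block above relies on `compute_metrics` from `eval_bleu_rouge`, which is not shown in this diff. A hedged sketch of what such a per-sample scorer might look like is given below; the `predict`/`label` field names and the tokenization choices are assumptions, and the repository's actual implementation may differ.

# Hedged sketch of a per-sample BLEU-4 / ROUGE scorer (assumed field names "predict" and "label").
import jieba
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge

def compute_metrics(example):
    pred_tokens = list(jieba.cut(example["predict"]))
    label_tokens = list(jieba.cut(example["label"]))
    rouge_scores = Rouge().get_scores(" ".join(pred_tokens), " ".join(label_tokens))[0]
    bleu = sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3)
    return {
        "bleu-4": bleu * 100,
        "rouge-1": rouge_scores["rouge-1"]["f"] * 100,
        "rouge-2": rouge_scores["rouge-2"]["f"] * 100,
        "rouge-l": rouge_scores["rouge-l"]["f"] * 100,
    }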
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> list[str]:
with open("requirements.txt", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> list[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
"torch-npu": ["torch-npu==2.5.1", "torchvision==0.20.1", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.9"],
"liger-kernel": ["liger-kernel>=0.5.5"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.24.0", "gptqmodel>=2.0.0"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.9.1"],
"sglang": ["sglang[srt]>=0.4.5", "transformers==4.51.1"],
"galore": ["galore-torch"],
"apollo": ["apollo-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
"dev": ["pre-commit", "ruff", "pytest", "build"],
}
def main():
setup(
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga@buaa.edu.cn",
description="Unified Efficient Fine-Tuning of 100+ LLMs",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.9.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
if __name__ == "__main__":
main()
...@@ -16,7 +16,7 @@ import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Annotated
from ..chat import ChatModel
from ..extras.constants import EngineName
...@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
api_key = os.getenv("API_KEY")
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
......
...@@ -26,7 +26,7 @@ from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
from .protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
...@@ -121,8 +121,10 @@ def _process_request(
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url): # local file
check_lfi_path(image_url)
image_stream = open(image_url, "rb")
else: # web uri
check_ssrf_url(image_url)
image_stream = requests.get(image_url, stream=True).raw
images.append(Image.open(image_stream).convert("RGB"))
...@@ -132,8 +134,10 @@ def _process_request(
if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(video_url): # local file
check_lfi_path(video_url)
video_stream = video_url
else: # web uri
check_ssrf_url(video_url)
video_stream = requests.get(video_url, stream=True).raw
videos.append(video_stream)
...@@ -143,8 +147,10 @@ def _process_request(
if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(audio_url): # local file
check_lfi_path(audio_url)
audio_stream = audio_url
else: # web uri
check_ssrf_url(audio_url)
audio_stream = requests.get(audio_url, stream=True).raw
audios.append(audio_stream)
......
...@@ -12,14 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ipaddress
import json
import os
import socket
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available
if is_fastapi_available():
from fastapi import HTTPException, status
if TYPE_CHECKING:
from pydantic import BaseModel
SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
def dictify(data: "BaseModel") -> dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
...@@ -32,3 +47,50 @@ def jsonify(data: "BaseModel") -> str:
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
def check_lfi_path(path: str) -> None:
"""Checks if a given path is vulnerable to LFI. Raises HTTPException if unsafe."""
if not ALLOW_LOCAL_FILES:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")
try:
os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
real_path = os.path.realpath(path)
safe_path = os.path.realpath(SAFE_MEDIA_PATH)
if not real_path.startswith(safe_path):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
)
except HTTPException: # do not mask the 403 raised above as a 400
raise
except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")
def check_ssrf_url(url: str) -> None:
"""Checks if a given URL is vulnerable to SSRF. Raises HTTPException if unsafe."""
try:
parsed_url = urlparse(url)
if parsed_url.scheme not in ["http", "https"]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")
hostname = parsed_url.hostname
if not hostname:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")
ip_info = socket.getaddrinfo(hostname, parsed_url.port)
ip_address_str = ip_info[0][4][0]
ip = ipaddress.ip_address(ip_address_str)
if not ip.is_global:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access to private or reserved IP addresses is not allowed.",
)
except socket.gaierror:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
)
except HTTPException: # do not mask the 400/403 raised above
raise
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
...@@ -14,10 +14,9 @@
import time
from enum import Enum, unique
from typing import Any, Literal
from pydantic import BaseModel, Field
@unique
...@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: FunctionDefinition | None = None
class FunctionCall(BaseModel):
...@@ -77,35 +76,35 @@ class URL(BaseModel):
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url", "video_url", "audio_url"]
text: str | None = None
image_url: URL | None = None
video_url: URL | None = None
audio_url: URL | None = None
class ChatMessage(BaseModel):
role: Role
content: str | list[MultimodalInputItem] | None = None
tool_calls: list[FunctionCall] | None = None
class ChatCompletionMessage(BaseModel):
role: Role | None = None
content: str | None = None
tool_calls: list[FunctionCall] | None = None
class ChatCompletionRequest(BaseModel):
model: str
messages: list[ChatMessage]
tools: list[FunctionAvailable] | None = None
do_sample: bool | None = None
temperature: float | None = None
top_p: float | None = None
n: int = 1
presence_penalty: float | None = None
max_tokens: int | None = None
stop: str | list[str] | None = None
stream: bool = False
...@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Finish | None = None
class ChatCompletionResponseUsage(BaseModel):
...@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
class ScoreEvaluationRequest(BaseModel):
model: str
messages: list[str]
max_length: int | None = None
class ScoreEvaluationResponse(BaseModel):
......