Commit 7ea81099 authored by chenych

update llama4

parent 84987715
@@ -15,7 +15,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any

 import fire
 import torch
@@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"

 def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
-    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
             with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                 for key in f.keys():
                     qwen_state_dict[key] = f.get_tensor(key)

-    llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
     torch_dtype = None
     for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
         if torch_dtype is None:
@@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
     with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
-        qwen_config_dict: Dict[str, Any] = json.load(f)
+        qwen_config_dict: dict[str, Any] = json.load(f)

-    llama2_config_dict: Dict[str, Any] = OrderedDict()
+    llama2_config_dict: dict[str, Any] = OrderedDict()
     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
     llama2_config_dict["hidden_act"] = "silu"
     llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
@@ -147,8 +147,8 @@ def llamafy_qwen(
     shard_size: str = "2GB",
     save_safetensors: bool = False,
 ):
-    r"""
-    Converts the Qwen models in the same format as LLaMA2.
+    r"""Convert the Qwen models in the same format as LLaMA2.

     Usage: python llamafy_qwen.py --input_dir input --output_dir output
     Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
     """
...
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig


if __name__ == "__main__":
    # Build a tiny, randomly initialized Llama 4 checkpoint for fast tests.
    vision_config = Llama4VisionConfig(
        hidden_size=1408,
        image_size=336,
        intermediate_size=5632,
        num_attention_heads=16,
        num_hidden_layers=4,
        vision_output_dim=4096,
    )
    text_config = Llama4TextConfig(
        hidden_size=512,
        intermediate_size=1024,
        intermediate_size_mlp=1024,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=2,
        head_dim=512 // 8,  # 64
        num_local_experts=2,
    )
    config = Llama4Config(vision_config=vision_config, text_config=text_config)
    # _from_config builds the model from the config without downloading pretrained weights.
    model = Llama4ForConditionalGeneration._from_config(config)
    model.save_pretrained("tiny-llama4")
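A quick way to sanity-check the tiny checkpoint produced above (a minimal sketch, assuming a transformers release that ships the Llama 4 classes, e.g. >=4.51, and the "tiny-llama4" directory written by the script):

from transformers import Llama4ForConditionalGeneration

model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4")
print(f"total parameters: {sum(p.numel() for p in model.parameters()):,}")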
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import time
import fire
from datasets import load_dataset
try:
    import jieba
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    from rouge_chinese import Rouge

    jieba.setLogLevel(logging.CRITICAL)
    jieba.initialize()
except ImportError:
    print("Please install llamafactory with `pip install -e .[metrics]`.")
    raise


def compute_metrics(sample):
    # ROUGE uses jieba word segmentation; BLEU-4 is computed on character-level n-grams with smoothing.
    hypothesis = list(jieba.cut(sample["predict"]))
    reference = list(jieba.cut(sample["label"]))

    bleu_score = sentence_bleu(
        [list(sample["label"])],
        list(sample["predict"]),
        smoothing_function=SmoothingFunction().method3,
    )

    # Fall back to zero ROUGE scores when either side is empty after segmentation.
    if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
        result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
    else:
        rouge = Rouge()
        scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
        result = scores[0]

    metric_result = {}
    for k, v in result.items():
        metric_result[k] = round(v["f"] * 100, 4)

    metric_result["bleu-4"] = round(bleu_score * 100, 4)
    return metric_result


def main(filename: str):
    start_time = time.time()
    dataset = load_dataset("json", data_files=filename, split="train")
    dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
    score_dict = dataset.to_dict()

    # Average each metric over the dataset and dump the scores to predictions_score.json.
    average_score = {}
    for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
        print(f"{task}: {sum(scores) / len(scores):.4f}")
        average_score[task] = sum(scores) / len(scores)

    with open("predictions_score.json", "w", encoding="utf-8") as f:
        json.dump(average_score, f, indent=4)

    print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to predictions_score.json")


if __name__ == "__main__":
    fire.Fire(main)
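For context, the evaluator above expects a JSON-lines file in which every record carries "predict" and "label" strings; a minimal sketch of producing such a file (the file name, script name, and contents below are illustrative, not taken from this commit):

import json

samples = [
    {"predict": "今天天气很好", "label": "今天天气不错"},
    {"predict": "The weather is nice today.", "label": "It is sunny today."},
]
with open("generated_predictions.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

# then, assuming the script above is saved as eval_bleu_rouge.py:
# python eval_bleu_rouge.py generated_predictions.jsonl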
@@ -18,7 +18,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

 import fire
 import torch
@@ -44,11 +44,11 @@ def block_expansion(
     shard_size: str = "5GB",
     save_safetensors: bool = True,
 ):
-    r"""
-    Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+    r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.

     Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
     """
-    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+    config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
     num_layers = getattr(config, "num_hidden_layers")
     if num_layers % num_expand != 0:
         raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
@@ -70,7 +70,7 @@ def block_expansion(
     split = num_layers // num_expand
     layer_cnt = 0
     state_dict = model.state_dict()
-    output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
+    output_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
             if f".{i:d}." in key:
...
@@ -38,8 +38,8 @@ def quantize_loftq(
     lora_target: tuple = ("q_proj", "v_proj"),
     save_safetensors: bool = True,
 ):
-    r"""
-    Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+    r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).

     Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
     """
     if isinstance(lora_target, str):
@@ -72,7 +72,7 @@ def quantize_loftq(
     print(f"Adapter weights saved in {loftq_dir}")

     # Save base model
-    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model: PreTrainedModel = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
     print(f"Model weights saved in {output_dir}")
...
@@ -37,8 +37,8 @@ def quantize_pissa(
     lora_target: tuple = ("q_proj", "v_proj"),
     save_safetensors: bool = True,
 ):
-    r"""
-    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+    r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).

     Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
     """
     if isinstance(lora_target, str):
@@ -67,7 +67,7 @@ def quantize_pissa(
     print(f"Adapter weights saved in {pissa_dir}")

     # Save base model
-    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model: PreTrainedModel = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
     print(f"Model weights saved in {output_dir}")
...
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import fire
from peft import PeftModel
from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration
def merge_lora(
    base_model_path: str,
    lora_checkpoint_path: str,
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
    save_path: str = "./merged_model_checkpoint",
):
    """Load the original model, tokenizer, and processor configuration, merge the LoRA weights
    for a specified submodule, and save the final merged model along with its configurations.

    Args:
        base_model_path (str): Path to the original model directory.
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
        save_path (str): Directory where the merged model and configurations will be saved.
    """
    # 1. Load the original model, tokenizer, and processor
    model = AutoModel.from_pretrained(base_model_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    try:
        processor = AutoProcessor.from_pretrained(base_model_path)
    except Exception:
        print("Processor configuration not found, skipping processor load.")
        processor = None

    print("Successfully loaded the original model, tokenizer, and processor (if available).")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
    print("LoRA weights loaded successfully.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("LoRA weights merged successfully.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    if processor is not None:
        processor.save_pretrained(save_path)

    print(f"Merged model and configuration saved to {save_path}.")

    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str,
    extra_file: str = "spk_dict.pt",
):
    """Load the saved thinker module and the original model, replace the thinker in the original
    model, and then save the complete model along with its tokenizer and processor configuration.

    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the final complete model will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # Load the thinker module
    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(saved_thinker_path, device_map="cpu")
    # Load the original model
    base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
    # Replace the thinker module in the original model
    base_model.thinker = thinker

    # Load the processor and tokenizer
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)

    # Save the complete model along with its configurations
    base_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")

    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
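A hedged usage sketch for the merge helpers above (the script name and every path below are placeholders, not taken from the commit); the same functions are also reachable as fire sub-commands:

# python qwen_omni_merge.py merge_lora --base_model_path <base> --lora_checkpoint_path <adapter> --save_path ./merged_model_checkpoint
# python qwen_omni_merge.py save_full --saved_thinker_path <thinker> --base_model_path <base> --save_path <out>
merge_lora(
    base_model_path="path/to/Qwen2.5-Omni-base",        # placeholder
    lora_checkpoint_path="path/to/lora_checkpoint",      # placeholder
    save_path="./merged_model_checkpoint",
)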
@@ -29,8 +29,8 @@ def calculate_flops(
     seq_length: int = 512,
     flash_attn: str = "auto",
 ):
-    r"""
-    Calculates the flops of pre-trained models.
+    r"""Calculate the flops of pre-trained models.

     Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
     """
     with get_accelerator().device(0):
...
@@ -45,8 +45,8 @@ def calculate_lr(
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
     packing: bool = False,
 ):
-    r"""
-    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+    r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.

     Usage:
         python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
     """
@@ -89,9 +89,8 @@ def calculate_lr(
     lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
     lr = lr / 6.0 if is_mistral_or_gemma else lr
     print(
-        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
-            lr, valid_ratio * 100, token_batch_size
-        )
+        f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
+        f"and effective token batch size {token_batch_size:.2f}"
     )
...
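To make the square-root scaling above concrete, a small worked sketch (BASE_LR and BASE_BS are assumed reference constants, not shown in this hunk):

import math

BASE_LR = 3e-4       # assumed reference learning rate
BASE_BS = 4_000_000  # assumed reference token batch size
token_batch_size = 16 * 1024  # e.g. batch size 16 with ~1024 valid tokens per sample

lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)
print(f"{lr:.2e}")  # ~1.92e-05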
@@ -34,9 +34,7 @@ def compute_model_flops(
     include_recompute: bool = False,
     include_flashattn: bool = False,
 ) -> int:
-    r"""
-    Calculates the FLOPs of model per forward/backward pass.
-    """
+    r"""Calculate the FLOPs of model per forward/backward pass."""
     config = AutoConfig.from_pretrained(model_name_or_path)
     hidden_size = getattr(config, "hidden_size", None)
     vocab_size = getattr(config, "vocab_size", None)
@@ -86,9 +84,7 @@ def compute_model_flops(
 def compute_device_flops(world_size: int) -> float:
-    r"""
-    Calculates the FLOPs of the device capability per second.
-    """
+    r"""Calculate the FLOPs of the device capability per second."""
     device_name = torch.cuda.get_device_name()
     if "H100" in device_name or "H800" in device_name:
         return 989 * 1e12 * world_size
@@ -114,8 +110,8 @@ def calculate_mfu(
     liger_kernel: bool = False,
     unsloth_gc: bool = False,
 ) -> float:
-    r"""
-    Calculates MFU for given model and hyper-params.
+    r"""Calculate MFU for given model and hyper-params.

     Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
     args = {
...
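For background, MFU ties the two helpers above together roughly as follows (a sketch of the standard definition, not code from this commit):

def model_flops_utilization(model_flops_per_step: int, steps_per_second: float, world_size: int) -> float:
    # MFU = achieved model FLOP/s divided by the aggregate peak FLOP/s of the devices.
    achieved_flops_per_second = model_flops_per_step * steps_per_second
    return achieved_flops_per_second / compute_device_flops(world_size)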
@@ -14,7 +14,7 @@
 import json
 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Optional, Sequence
+from typing import Any, Literal, Optional

 import fire
 import torch
@@ -30,16 +30,12 @@ from llamafactory.model import load_model, load_tokenizer
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for pairwise data.
-    """
+    r"""Data collator for pairwise data."""

     train_on_prompt: bool = False

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-        r"""
-        Pads batched data to the longest sequence in the batch.
-        """
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
+        r"""Pad batched data to the longest sequence in the batch."""
         chosen_features = []
         for feature in features:
             chosen_features.append(
@@ -68,8 +64,8 @@ def calculate_ppl(
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
-    r"""
-    Calculates the ppl on the dataset of the pre-trained models.
+    r"""Calculate the ppl on the dataset of the pre-trained models.

     Usage: export CUDA_VISIBLE_DEVICES=0
         python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
     """
@@ -111,17 +107,17 @@ def calculate_ppl(
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
     total_ppl = 0
     perplexities = []
-    batch: Dict[str, "torch.Tensor"]
+    batch: dict[str, torch.Tensor]
     with torch.no_grad():
         for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
-            shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
-            shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
+            shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
+            shift_labels: torch.Tensor = batch["labels"][..., 1:]
             loss_mask = shift_labels != IGNORE_INDEX
             flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
             flatten_labels = shift_labels.contiguous().view(-1)
-            token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
+            token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
             token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
             sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
             total_ppl += sentence_logps.exp().sum().item()
...
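For reference, the loop above turns token-level losses into a per-sentence perplexity by exponentiating the mean negative log-likelihood over non-masked positions; a minimal equivalent sketch:

import torch

def sentence_perplexity(token_nll: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # token_nll: (batch, seq_len) per-token NLL from CrossEntropyLoss(reduction="none");
    # loss_mask: 1.0 where labels are real tokens, 0.0 where they equal IGNORE_INDEX.
    mean_nll = (token_nll * loss_mask).sum(-1) / loss_mask.sum(-1)
    return mean_nll.exp()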
@@ -29,8 +29,8 @@ def length_cdf(
     template: str = "default",
     interval: int = 1000,
 ):
-    r"""
-    Calculates the distribution of the input lengths in the dataset.
+    r"""Calculate the distribution of the input lengths in the dataset.

     Usage: export CUDA_VISIBLE_DEVICES=0
         python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
     """
...
@@ -52,11 +52,11 @@ def vllm_infer(
     image_max_pixels: int = 768 * 768,
     image_min_pixels: int = 32 * 32,
 ):
-    r"""
-    Performs batch generation using vLLM engine, which supports tensor parallelism.
+    r"""Perform batch generation using vLLM engine, which supports tensor parallelism.

     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.7.3")
+    check_version("vllm>=0.4.3,<=0.8.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
@@ -92,8 +92,20 @@ def vllm_infer(
             multi_modal_data = {
                 "image": template_obj.mm_plugin._regularize_images(
                     sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )
+                )["images"]
+            }
+        elif sample["videos"]:
+            multi_modal_data = {
+                "video": template_obj.mm_plugin._regularize_videos(
+                    sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )["videos"]
             }
+        elif sample["audios"]:
+            audio_data = template_obj.mm_plugin._regularize_audios(
+                sample["audios"],
+                sampling_rate=16000,
+            )
+            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
         else:
             multi_modal_data = None
@@ -131,7 +143,7 @@ def vllm_infer(
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
     if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
...
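The new audio branch above hands vLLM an iterable of (waveform, sampling_rate) pairs; a minimal sketch of that structure with made-up data:

import numpy as np

audio_data = {
    "audios": [np.zeros(16000, dtype=np.float32)],  # one second of silence at 16 kHz (illustrative)
    "sampling_rates": [16000],
}
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}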
@@ -14,7 +14,6 @@
 import os
 import re
-from typing import List

 from setuptools import find_packages, setup
@@ -27,14 +26,14 @@ def get_version() -> str:
     return version


-def get_requires() -> List[str]:
+def get_requires() -> list[str]:
     with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
         return lines


-def get_console_scripts() -> List[str]:
+def get_console_scripts() -> list[str]:
     console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
     if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
         console_scripts.append("lmf = llamafactory.cli:main")
@@ -47,14 +46,15 @@ extra_require = {
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
-    "liger-kernel": ["liger-kernel"],
+    "liger-kernel": ["liger-kernel>=0.5.5"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
     "eetq": ["eetq"],
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.7.3"],
+    "vllm": ["vllm>=0.4.3,<=0.8.2"],
+    "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],
@@ -69,6 +69,7 @@ extra_require = {
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "transformers==4.48.3",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
@@ -82,11 +83,11 @@ def main():
         name="llamafactory",
         version=get_version(),
         author="hiyouga",
-        author_email="hiyouga AT buaa.edu.cn",
-        description="Easy-to-use LLM fine-tuning framework",
+        author_email="hiyouga@buaa.edu.cn",
+        description="Unified Efficient Fine-Tuning of 100+ LLMs",
         long_description=open("README.md", encoding="utf-8").read(),
         long_description_content_type="text/markdown",
-        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
+        keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
         license="Apache 2.0 License",
         url="https://github.com/hiyouga/LLaMA-Factory",
         package_dir={"": "src"},
...
@@ -12,18 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-r"""
-Efficient fine-tuning of large language models.
+r"""Efficient fine-tuning of large language models.

 Level:
     api, webui > chat, eval, train > data, model > hparams > extras

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
-    datasets>=2.16.0,<=3.2.0
-    accelerate>=0.34.0,<=1.2.1
-    peft>=0.11.1,<=0.12.0
+    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.4.1
+    accelerate>=0.34.0,<=1.5.2
+    peft>=0.14.0,<=0.15.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
...
@@ -16,9 +16,7 @@ import asyncio
 import os
 from contextlib import asynccontextmanager
 from functools import partial
-from typing import Optional
-
-from typing_extensions import Annotated
+from typing import Annotated, Optional

 from ..chat import ChatModel
 from ..extras.constants import EngineName
...
@@ -18,11 +18,12 @@ import json
 import os
 import re
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional

 from ..data import Role as DataRole
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import is_env_enabled
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
@@ -55,7 +56,7 @@ if is_requests_available():
 if TYPE_CHECKING:
     from ..chat import ChatModel
-    from ..data.mm_plugin import ImageInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@@ -71,7 +72,14 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+) -> tuple[
+    list[dict[str, str]],
+    Optional[str],
+    Optional[str],
+    Optional[list["ImageInput"]],
+    Optional[list["VideoInput"]],
+    Optional[list["AudioInput"]],
+]:
     if is_env_enabled("API_VERBOSE", "1"):
         logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
@@ -87,7 +95,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
-    images = []
+    images, videos, audios = [], [], []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -106,7 +114,7 @@ def _process_request(
             for input_item in message.content:
                 if input_item.type == "text":
                     text_content += input_item.text
-                else:
+                elif input_item.type == "image_url":
                     text_content += IMAGE_PLACEHOLDER
                     image_url = input_item.image_url.url
                     if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
@@ -117,6 +125,28 @@ def _process_request(
                         image_stream = requests.get(image_url, stream=True).raw

                     images.append(Image.open(image_stream).convert("RGB"))
+                elif input_item.type == "video_url":
+                    text_content += VIDEO_PLACEHOLDER
+                    video_url = input_item.video_url.url
+                    if os.path.isfile(video_url):  # local file
+                        video_stream = open(video_url, "rb")
+                    else:  # web uri
+                        video_stream = requests.get(video_url, stream=True).raw
+
+                    videos.append(video_stream)
+                elif input_item.type == "audio_url":
+                    text_content += AUDIO_PLACEHOLDER
+                    audio_url = input_item.audio_url.url
+                    if os.path.isfile(audio_url):  # local file
+                        audio_stream = open(audio_url, "rb")
+                    else:  # web uri
+                        audio_stream = requests.get(audio_url, stream=True).raw
+
+                    audios.append(audio_stream)
+                else:
+                    raise HTTPException(
+                        status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
+                    )

             input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
         else:
@@ -131,7 +161,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, images or None
+    return input_messages, system, tools, images or None, videos or None, audios or None
 def _create_stream_chat_completion_chunk(
@@ -150,12 +180,14 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, images = _process_request(request)
+    input_messages, system, tools, images, videos, audios = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -201,7 +233,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, images = _process_request(request)
+    input_messages, system, tools, images, videos, audios = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -216,6 +248,8 @@ async def create_stream_chat_completion_response(
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
...
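To illustrate the request shape the handlers above now accept, a hedged example payload (the endpoint, port, model name, and media URLs below are placeholders, assuming the usual OpenAI-compatible chat completions route):

import requests

payload = {
    "model": "test-model",  # placeholder
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this clip and the speech in it."},
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "audio_url", "audio_url": {"url": "https://example.com/speech.wav"}},
            ],
        }
    ],
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json())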
@@ -13,14 +13,14 @@
 # limitations under the License.

 import json
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any

 if TYPE_CHECKING:
     from pydantic import BaseModel


-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel") -> dict[str, Any]:
     try:  # pydantic v2
         return data.model_dump(exclude_unset=True)
     except AttributeError:  # pydantic v1
...
@@ -14,7 +14,7 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -45,7 +45,7 @@ class ModelCard(BaseModel):
 class ModelList(BaseModel):
     object: Literal["list"] = "list"
-    data: List[ModelCard] = []
+    data: list[ModelCard] = []


 class Function(BaseModel):
@@ -56,7 +56,7 @@ class Function(BaseModel):
 class FunctionDefinition(BaseModel):
     name: str
     description: str
-    parameters: Dict[str, Any]
+    parameters: dict[str, Any]


 class FunctionAvailable(BaseModel):
@@ -70,38 +70,41 @@ class FunctionCall(BaseModel):
     function: Function


-class ImageURL(BaseModel):
+class URL(BaseModel):
     url: str
+    detail: Literal["auto", "low", "high"] = "auto"


 class MultimodalInputItem(BaseModel):
-    type: Literal["text", "image_url"]
+    type: Literal["text", "image_url", "video_url", "audio_url"]
     text: Optional[str] = None
-    image_url: Optional[ImageURL] = None
+    image_url: Optional[URL] = None
+    video_url: Optional[URL] = None
+    audio_url: Optional[URL] = None


 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, List[MultimodalInputItem]]] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    content: Optional[Union[str, list[MultimodalInputItem]]] = None
+    tool_calls: Optional[list[FunctionCall]] = None


 class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    tool_calls: Optional[list[FunctionCall]] = None


 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[ChatMessage]
-    tools: Optional[List[FunctionAvailable]] = None
+    messages: list[ChatMessage]
+    tools: Optional[list[FunctionAvailable]] = None
     do_sample: Optional[bool] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: Optional[Union[str, list[str]]] = None
     stream: bool = False
@@ -128,7 +131,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseChoice]
+    choices: list[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage
@@ -137,12 +140,12 @@ class ChatCompletionStreamResponse(BaseModel):
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionStreamResponseChoice]
+    choices: list[ChatCompletionStreamResponseChoice]


 class ScoreEvaluationRequest(BaseModel):
     model: str
-    messages: List[str]
+    messages: list[str]
     max_length: Optional[int] = None
@@ -150,4 +153,4 @@ class ScoreEvaluationResponse(BaseModel):
     id: str
     object: Literal["score.evaluation"] = "score.evaluation"
     model: str
-    scores: List[float]
+    scores: list[float]
@@ -13,8 +13,9 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union

 if TYPE_CHECKING:
@@ -36,8 +37,7 @@ class Response:
 class BaseEngine(ABC):
-    r"""
-    Base class for inference engine of chat models.
+    r"""Base class for inference engine of chat models.

     Must implements async methods: chat(), stream_chat() and get_scores().
     """
@@ -47,7 +47,7 @@ class BaseEngine(ABC):
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
     template: "Template"
-    generating_args: Dict[str, Any]
+    generating_args: dict[str, Any]

     @abstractmethod
     def __init__(
@@ -57,50 +57,42 @@ class BaseEngine(ABC):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
-        r"""
-        Initializes an inference engine.
-        """
+        r"""Initialize an inference engine."""
         ...

     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         ...

     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         ...

     @abstractmethod
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         ...