Commit 12d5cbac authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #1780 canceled with stages
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import OrderedDict
from typing import Any, Dict
import fire
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
from transformers.utils import check_min_version
try:
check_min_version("4.34.0")
except Exception:
raise ValueError("Please upgrade `transformers` to 4.34.0")
CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
for key in f.keys():
qwen_state_dict[key] = f.get_tensor(key)
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
torch_dtype = None
for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
if torch_dtype is None:
torch_dtype = value.dtype
if "wte" in key:
llama2_state_dict["model.embed_tokens.weight"] = value
elif "ln_f" in key:
llama2_state_dict["model.norm.weight"] = value
else:
key = key.replace("transformer.h", "model.layers")
if "attn.c_attn" in key:
proj_size = value.size(0) // 3
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[
proj_size : 2 * proj_size, ...
]
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
elif "attn.c_proj" in key:
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
value[:, 0]
).squeeze()
elif "ln_1" in key:
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
elif "ln_2" in key:
llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
elif "mlp.w1" in key:
llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
elif "mlp.w2" in key:
llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
elif "mlp.c_proj" in key:
llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
elif "lm_head" in key:
llama2_state_dict[key] = value
else:
raise KeyError("Unable to process key {}".format(key))
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
if save_safetensors:
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(output_dir, shard_file))
if index is None:
print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
else:
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
json.dump(index, f, indent=2, sort_keys=True)
print("Model weights saved in {}".format(output_dir))
return str(torch_dtype).replace("torch.", "")
def save_config(input_dir: str, output_dir: str, torch_dtype: str):
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
qwen_config_dict: Dict[str, Any] = json.load(f)
llama2_config_dict: Dict[str, Any] = OrderedDict()
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
llama2_config_dict["hidden_act"] = "silu"
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
llama2_config_dict["model_type"] = "llama"
llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
llama2_config_dict["pretraining_tp"] = 1
llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
llama2_config_dict["rope_scaling"] = None
llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
llama2_config_dict["torch_dtype"] = torch_dtype
llama2_config_dict["transformers_version"] = "4.34.0"
llama2_config_dict["use_cache"] = True
llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
llama2_config_dict["attention_bias"] = True
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
json.dump(llama2_config_dict, f, indent=2)
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_qwen(
input_dir: str,
output_dir: str,
shard_size: str = "2GB",
save_safetensors: bool = False,
):
r"""
Converts the Qwen models in the same format as LLaMA2.
Usage: python llamafy_qwen.py --input_dir input --output_dir output
Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
"""
try:
os.makedirs(output_dir, exist_ok=False)
except Exception as e:
raise print("Output dir already exists", e)
torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
save_config(input_dir, output_dir, torch_dtype)
if __name__ == "__main__":
fire.Fire(llamafy_qwen)
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import fire
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedModel
def quantize_loftq(
model_name_or_path: str,
output_dir: str,
loftq_bits: int = 4,
loftq_iter: int = 4,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
lora_target = [name.strip() for name in lora_target.split(",")]
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=True,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=lora_target,
init_lora_weights="loftq",
loftq_config=loftq_config,
)
# Init LoftQ model
print("Initializing LoftQ weights, it may be take several minutes, wait patiently.")
peft_model = get_peft_model(model, lora_config)
loftq_dir = os.path.join(output_dir, "loftq_init")
# Save LoftQ model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again
peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(loftq_dir))
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(loftq_dir))
print("finetuning_type: lora")
print("quantization_bit: {}".format(loftq_bits))
if __name__ == "__main__":
fire.Fire(quantize_loftq)
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import fire
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedModel
def quantize_pissa(
model_name_or_path: str,
output_dir: str,
pissa_iter: int = 16,
lora_alpha: int = None,
lora_rank: int = 16,
lora_dropout: float = 0,
lora_target: tuple = ("q_proj", "v_proj"),
save_safetensors: bool = True,
):
r"""
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
"""
if isinstance(lora_target, str):
lora_target = [name.strip() for name in lora_target.split(",")]
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=lora_dropout,
target_modules=lora_target,
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
)
# Init PiSSA model
peft_model = get_peft_model(model, lora_config)
pissa_dir = os.path.join(output_dir, "pissa_init")
# Save PiSSA model
setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again
peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
print("Adapter weights saved in {}".format(pissa_dir))
# Save base model
base_model: "PreTrainedModel" = peft_model.unload()
base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
tokenizer.save_pretrained(output_dir)
print("Model weights saved in {}".format(output_dir))
print("- Fine-tune this model with:")
print("model_name_or_path: {}".format(output_dir))
print("adapter_name_or_path: {}".format(pissa_dir))
print("finetuning_type: lora")
print("pissa_init: false")
print("pissa_convert: true")
print("- and optionally with:")
print("quantization_bit: 4")
if __name__ == "__main__":
fire.Fire(quantize_pissa)
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Sequence
from openai import OpenAI
from transformers.utils.versions import require_version
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
grade_to_score = {"A": 4, "B": 3, "C": 2}
total_score, total_hour = 0, 0
for grade, hour in zip(grades, hours):
total_score += grade_to_score[grade] * hour
total_hour += hour
return round(total_score / total_hour, 2)
def main():
client = OpenAI(
api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
tools = [
{
"type": "function",
"function": {
"name": "calculate_gpa",
"description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
"parameters": {
"type": "object",
"properties": {
"grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
"hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
},
"required": ["grades", "hours"],
},
},
}
]
tool_map = {"calculate_gpa": calculate_gpa}
messages = []
messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
result = client.chat.completions.create(messages=messages, model="test", tools=tools)
if result.choices[0].message.tool_calls is None:
raise ValueError("Cannot retrieve function call from the response.")
messages.append(result.choices[0].message)
tool_call = result.choices[0].message.tool_calls[0].function
print(tool_call)
# Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
name, arguments = tool_call.name, json.loads(tool_call.arguments)
tool_result = tool_map[name](**arguments)
messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
result = client.chat.completions.create(messages=messages, model="test", tools=tools)
print(result.choices[0].message.content)
# Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.
if __name__ == "__main__":
main()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import List
from setuptools import find_packages, setup
def get_version() -> str:
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
file_content = f.read()
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
return version
def get_requires() -> List[str]:
with open("requirements.txt", "r", encoding="utf-8") as f:
file_content = f.read()
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
return lines
def get_console_scripts() -> List[str]:
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
console_scripts.append("lmf = llamafactory.cli:main")
return console_scripts
extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
"liger-kernel": ["liger-kernel"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
"eetq": ["eetq"],
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"],
"vllm": ["vllm>=0.4.3,<=0.6.3"],
"galore": ["galore-torch"],
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"dev": ["ruff", "pytest"],
}
def main():
setup(
name="llamafactory",
version=get_version(),
author="hiyouga",
author_email="hiyouga" "@" "buaa.edu.cn",
description="Easy-to-use LLM fine-tuning framework",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
license="Apache 2.0 License",
url="https://github.com/hiyouga/LLaMA-Factory",
package_dir={"": "src"},
packages=find_packages("src"),
python_requires=">=3.8.0",
install_requires=get_requires(),
extras_require=extra_require,
entry_points={"console_scripts": get_console_scripts()},
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
)
if __name__ == "__main__":
main()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uvicorn
from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel
def main():
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)
if __name__ == "__main__":
main()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Efficient fine-tuning of large language models.
Level:
api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
main:
transformers>=4.41.2,<=4.45.2
datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.34.2
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<=4.45.2
packing:
transformers>=4.41.2,<=4.45.2
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
Use openmind: USE_OPENMIND_HUB=1
"""
from .extras.env import VERSION
__version__ = VERSION
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional
from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
create_chat_completion_response,
create_score_evaluation_response,
create_stream_chat_completion_response,
)
from .protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ModelCard,
ModelList,
ScoreEvaluationRequest,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
async def sweeper() -> None:
while True:
torch_gc()
await asyncio.sleep(300)
@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine_type == "huggingface":
asyncio.create_task(sweeper())
yield
torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
api_key = os.environ.get("API_KEY", None)
security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
if api_key and (auth is None or auth.credentials != api_key):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
@app.get(
"/v1/models",
response_model=ModelList,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def list_models():
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card])
@app.post(
"/v1/chat/completions",
response_model=ChatCompletionResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_chat_completion(request: ChatCompletionRequest):
if not chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if request.stream:
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
return await create_chat_completion_response(request, chat_model)
@app.post(
"/v1/score/evaluation",
response_model=ScoreEvaluationResponse,
status_code=status.HTTP_200_OK,
dependencies=[Depends(verify_api_key)],
)
async def create_score_evaluation(request: ScoreEvaluationRequest):
if chat_model.engine.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
return await create_score_evaluation_response(request, chat_model)
return app
def run_api() -> None:
chat_model = ChatModel()
app = create_app(chat_model)
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
print("Visit http://localhost:{}/docs for API document.".format(api_port))
uvicorn.run(app, host=api_host, port=api_port)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import json
import os
import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseUsage,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
Finish,
Function,
FunctionCall,
Role,
ScoreEvaluationResponse,
)
if is_fastapi_available():
from fastapi import HTTPException, status
if is_pillow_available():
from PIL import Image
if is_requests_available():
import requests
if TYPE_CHECKING:
from ..chat import ChatModel
from ..data.mm_plugin import ImageInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
logger = get_logger(__name__)
ROLE_MAPPING = {
Role.USER: DataRole.USER.value,
Role.ASSISTANT: DataRole.ASSISTANT.value,
Role.SYSTEM: DataRole.SYSTEM.value,
Role.FUNCTION: DataRole.FUNCTION.value,
Role.TOOL: DataRole.OBSERVATION.value,
}
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
if request.messages[0].role == Role.SYSTEM:
system = request.messages.pop(0).content
else:
system = None
if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
image = None
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
tool_calls = [
{"name": tool_call.function.name, "arguments": tool_call.function.arguments}
for tool_call in message.tool_calls
]
content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
if input_item.type == "text":
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url): # local file
image_stream = open(image_url, "rb")
else: # web uri
image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB")
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except json.JSONDecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
return input_messages, system, tools, image
def _create_stream_chat_completion_chunk(
completion_id: str,
model: str,
delta: "ChatCompletionMessage",
index: Optional[int] = 0,
finish_reason: Optional["Finish"] = None,
) -> str:
choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
return jsonify(chunk)
async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
image,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
stop=request.stop,
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.extract_tool(response.response_text)
else:
result = response.response_text
if isinstance(result, list):
tool_calls = []
for tool in result:
function = Function(name=tool[0], arguments=tool[1])
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
else:
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length + response_length,
)
return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
input_messages, system, tools, image = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
if request.n > 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
)
async for new_token in chat_model.astream_chat(
input_messages,
system,
tools,
image,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
max_new_tokens=request.max_tokens,
stop=request.stop,
):
if len(new_token) != 0:
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
)
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
)
yield "[DONE]"
async def create_score_evaluation_response(
request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
score_id = "scoreval-{}".format(uuid.uuid4().hex)
if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING:
from pydantic import BaseModel
def dictify(data: "BaseModel") -> Dict[str, Any]:
try: # pydantic v2
return data.model_dump(exclude_unset=True)
except AttributeError: # pydantic v1
return data.dict(exclude_unset=True)
def jsonify(data: "BaseModel") -> str:
try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except AttributeError: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Literal
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
TOOL = "tool"
@unique
class Finish(str, Enum):
STOP = "stop"
LENGTH = "length"
TOOL = "tool_calls"
class ModelCard(BaseModel):
id: str
object: Literal["model"] = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: Literal["owner"] = "owner"
class ModelList(BaseModel):
object: Literal["list"] = "list"
data: List[ModelCard] = []
class Function(BaseModel):
name: str
arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel):
id: str
type: Literal["function"] = "function"
function: Function
class ImageURL(BaseModel):
url: str
class MultimodalInputItem(BaseModel):
type: Literal["text", "image_url"]
text: Optional[str] = None
image_url: Optional[ImageURL] = None
class ChatMessage(BaseModel):
role: Role
content: Optional[Union[str, List[MultimodalInputItem]]] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel):
role: Optional[Role] = None
content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
tools: Optional[List[FunctionAvailable]] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: int = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatCompletionMessage
finish_reason: Finish
class ChatCompletionStreamResponseChoice(BaseModel):
index: int
delta: ChatCompletionMessage
finish_reason: Optional[Finish] = None
class ChatCompletionResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: ChatCompletionResponseUsage
class ChatCompletionStreamResponse(BaseModel):
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionStreamResponseChoice]
class ScoreEvaluationRequest(BaseModel):
model: str
messages: List[str]
max_length: Optional[int] = None
class ScoreEvaluationResponse(BaseModel):
id: str
object: Literal["score.evaluation"] = "score.evaluation"
model: str
scores: List[float]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_engine import BaseEngine
from .chat_model import ChatModel
__all__ = ["BaseEngine", "ChatModel"]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class BaseEngine(ABC):
r"""
Base class for inference engine of chat models.
Must implements async methods: chat(), stream_chat() and get_scores().
"""
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
template: "Template"
generating_args: Dict[str, Any]
@abstractmethod
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
r"""
Initializes an inference engine.
"""
...
@abstractmethod
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
...
@abstractmethod
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Gets the response token-by-token of the chat model.
"""
...
@abstractmethod
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
...
# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine
if TYPE_CHECKING:
from ..data.mm_plugin import ImageInput, VideoInput
from .base_engine import BaseEngine, Response
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
)
return task.result()
async def achat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True:
try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
yield task.result()
except StopAsyncIteration:
break
async def astream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token
def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
return await self.engine.get_scores(batch_input, **input_kwargs)
def run_chat() -> None:
if os.name != "nt":
try:
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
chat_model = ChatModel()
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
torch_gc()
print("History has been removed.")
continue
messages.append({"role": "user", "content": query})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in chat_model.stream_chat(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append({"role": "assistant", "content": response})
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from trl import PreTrainedModelWrapper
from ..data import Template
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class HuggingfaceEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab
self.generating_args = generating_args.to_dict()
try:
asyncio.get_event_loop()
except RuntimeError:
logger.warning("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
@staticmethod
def _process_args(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None:
mm_input_dict.update({"images": [image], "imglens": [1]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
if video is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"]
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
)
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
generating_args = generating_args.copy()
generating_args.update(
dict(
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature if temperature is not None else generating_args["temperature"],
top_p=top_p if top_p is not None else generating_args["top_p"],
top_k=top_k if top_k is not None else generating_args["top_k"],
num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
generating_args["do_sample"] = True
generating_args["temperature"] = generating_args["temperature"] or 1.0
if not generating_args["temperature"]:
generating_args["do_sample"] = False
if not generating_args["do_sample"]:
generating_args.pop("temperature", None)
generating_args.pop("top_p", None)
if max_length:
generating_args.pop("max_new_tokens", None)
generating_args["max_length"] = max_length
if max_new_tokens:
generating_args.pop("max_length", None)
generating_args["max_new_tokens"] = max_new_tokens
gen_kwargs = dict(
inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length
@staticmethod
@torch.inference_mode()
def _chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(
Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length",
)
)
return results
@staticmethod
@torch.inference_mode()
def _stream_chat(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
template: "Template",
generating_args: Dict[str, Any],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
@staticmethod
@torch.inference_mode()
def _get_scores(
model: "PreTrainedModelWrapper",
tokenizer: "PreTrainedTokenizer",
batch_input: List[str],
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List[float]:
max_length: Optional[int] = input_kwargs.pop("max_length", None)
device = getattr(model.pretrained_model, "device", "cuda")
inputs: Dict[str, "torch.Tensor"] = tokenizer(
batch_input,
padding=True,
truncation=True,
max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
return_tensors="pt",
add_special_tokens=False,
).to(device)
values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return scores
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
video,
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args)
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
loop = asyncio.get_running_loop()
input_args = (
self.model,
self.tokenizer,
self.processor,
self.template,
self.generating_args,
messages,
system,
tools,
image,
video,
input_kwargs,
)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
stream = self._stream_chat(*input_args)
while True:
try:
yield await loop.run_in_executor(pool, stream)
except StopAsyncIteration:
break
@override
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
loop = asyncio.get_running_loop()
input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
async with self.semaphore:
with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._get_scores, *input_args)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class VllmEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"download_dir": model_args.cache_dir,
"dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
self.lora_request = None
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
if image is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
use_beam_search: bool = self.generating_args["num_beams"] > 1
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = SamplingParams(
n=num_return_sequences,
repetition_penalty=(
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
use_beam_search=use_beam_search,
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens,
skip_special_tokens=True,
)
if image is not None: # add image features
if not isinstance(image, (str, ImageObject)):
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
if isinstance(image, str):
image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image}
else:
multi_modal_data = None
result_generator = self.model.generate(
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
request_id=request_id,
lora_request=self.lora_request,
)
return result_generator
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for request_output in generator:
final_output = request_output
results = []
for output in final_output.outputs:
results.append(
Response(
response_text=output.text,
response_length=len(output.token_ids),
prompt_length=len(final_output.prompt_token_ids),
finish_reason=output.finish_reason,
)
)
return results
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
@override
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import subprocess
import sys
from enum import Enum, unique
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
USAGE = (
"-" * 70
+ "\n"
+ "| Usage: |\n"
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ "| llamafactory-cli eval -h: evaluate models |\n"
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ "| llamafactory-cli train -h: train models |\n"
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
+ "| llamafactory-cli version: show version info |\n"
+ "-" * 70
)
WELCOME = (
"-" * 58
+ "\n"
+ "| Welcome to LLaMA Factory, version {}".format(VERSION)
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
logger = get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main():
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
nnodes=os.environ.get("NNODES", "1"),
node_rank=os.environ.get("RANK", "0"),
nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
),
shell=True,
)
sys.exit(process.returncode)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}.".format(command))
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .collator import (
KTODataCollatorWithPadding,
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(
images: Sequence["ImageInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
if len(images) == 0:
return None
images = images[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
images[i] = os.path.join(data_args.dataset_dir, images[i])
return images
def _convert_videos(
videos: Sequence["VideoInput"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if len(videos) == 0:
return None
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
return videos
def convert_alpaca(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts alpaca format dataset to the standard format.
"""
prompt = []
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])
if dataset_attr.query and example[dataset_attr.query]:
query.append(example[dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], str)
and isinstance(example[dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
]
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def convert_sharegpt(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts sharegpt format dataset to the standard format.
"""
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[dataset_attr.messages]
if (
dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], dict)
and isinstance(example[dataset_attr.rejected], dict)
): # pairwise example
chosen = example[dataset_attr.chosen]
rejected = example[dataset_attr.rejected]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def align_dataset(
dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=False,
remove_columns=column_names,
**kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment