"docker/rocm.Dockerfile" did not exist on "7c3a12c0002e33fed1e72f4157e74a64a998f251"
Commit 53b3977b authored by dongchy920's avatar dongchy920
Browse files

Initial commit

parents
Pipeline #2841 failed with stages
in 0 seconds
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_vllm_available():
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest
if TYPE_CHECKING:
from ..data.mm_plugin import ImageInput, VideoInput
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = logging.get_logger(__name__)
class VllmEngine(BaseEngine):
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
model_args.infer_dtype = "float16"
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": model_args.trust_remote_code,
"download_dir": model_args.cache_dir,
"dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
"tensor_parallel_size": get_device_count() or 1,
"gpu_memory_utilization": model_args.vllm_gpu_util,
"disable_log_stats": True,
"disable_log_requests": True,
"enforce_eager": model_args.vllm_enforce_eager,
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
if getattr(config, "is_yi_vl_derived_model", None):
import vllm.model_executor.models.llava
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None:
self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
self.lora_request = None
async def _generate(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if length_penalty is not None:
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
if max_new_tokens:
max_tokens = max_new_tokens
sampling_params = SamplingParams(
n=num_return_sequences,
repetition_penalty=(
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens,
skip_special_tokens=self.generating_args["skip_special_tokens"],
)
if images is not None: # add image features
multi_modal_data = {"image": []}
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str):
image = Image.open(image).convert("RGB")
multi_modal_data["image"].append(image)
else:
multi_modal_data = None
result_generator = self.model.generate(
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
sampling_params=sampling_params,
request_id=request_id,
lora_request=self.lora_request,
)
return result_generator
@override
async def chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> List["Response"]:
final_output = None
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for request_output in generator:
final_output = request_output
results = []
for output in final_output.outputs:
results.append(
Response(
response_text=output.text,
response_length=len(output.token_ids),
prompt_length=len(final_output.prompt_token_ids),
finish_reason=output.finish_reason,
)
)
return results
@override
async def stream_chat(
self,
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
images: Optional[Sequence["ImageInput"]] = None,
videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
generated_text = ""
generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text
yield delta_text
@override
async def get_scores(
self,
batch_input: List[str],
**input_kwargs,
) -> List[float]:
raise NotImplementedError("vLLM engine does not support get_scores.")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import subprocess
import sys
from enum import Enum, unique
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
USAGE = (
"-" * 70
+ "\n"
+ "| Usage: |\n"
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ "| llamafactory-cli eval -h: evaluate models |\n"
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ "| llamafactory-cli train -h: train models |\n"
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
+ "| llamafactory-cli version: show version info |\n"
+ "-" * 70
)
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
logger = logging.get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
def main():
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
)
.format(
nnodes=os.getenv("NNODES", "1"),
node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
)
.split()
)
sys.exit(process.returncode)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError(f"Unknown command: {command}.")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .collator import (
KTODataCollatorWithPadding,
MultiModalDataCollatorForSeq2Seq,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [
"KTODataCollatorWithPadding",
"MultiModalDataCollatorForSeq2Seq",
"PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask",
"Role",
"split_dataset",
"get_dataset",
"TEMPLATES",
"Template",
"get_template_and_fix_tokenizer",
]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras import logging
from .data_utils import Role
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr
logger = logging.get_logger(__name__)
def _convert_images(
images: Union["ImageInput", Sequence["ImageInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
if not isinstance(images, list):
images = [images]
elif len(images) == 0:
return None
else:
images = images[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
images[i] = os.path.join(data_args.image_dir, images[i])
return images
def _convert_videos(
videos: Union["VideoInput", Sequence["VideoInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if not isinstance(videos, list):
videos = [videos]
elif len(videos) == 0:
return None
else:
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
videos[i] = os.path.join(data_args.image_dir, videos[i])
return videos
def convert_alpaca(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts alpaca format dataset to the standard format.
"""
prompt = []
if dataset_attr.history and isinstance(example[dataset_attr.history], list):
for old_prompt, old_response in example[dataset_attr.history]:
prompt.append({"role": Role.USER.value, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
query = []
if dataset_attr.prompt and example[dataset_attr.prompt]:
query.append(example[dataset_attr.prompt])
if dataset_attr.query and example[dataset_attr.query]:
query.append(example[dataset_attr.query])
prompt.append({"role": Role.USER.value, "content": "\n".join(query)}) # "prompt\nquery"
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], str)
and isinstance(example[dataset_attr.rejected], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
{"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
]
elif dataset_attr.response and isinstance(example[dataset_attr.response], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def convert_sharegpt(
example: Dict[str, Any],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Dict[str, Any]:
r"""
Converts sharegpt format dataset to the standard format.
"""
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
dataset_attr.observation_tag: Role.OBSERVATION.value,
dataset_attr.function_tag: Role.FUNCTION.value,
dataset_attr.system_tag: Role.SYSTEM.value,
}
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[dataset_attr.messages]
if (
dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if example[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(example[dataset_attr.chosen], dict)
and isinstance(example[dataset_attr.rejected], dict)
): # pairwise example
chosen = example[dataset_attr.chosen]
rejected = example[dataset_attr.rejected]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
}
return output
def align_dataset(
dataset: Union["Dataset", "IterableDataset"],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Aligned dataset:
_prompt: [{"role": "user", "content": "..."}] * (2T - 1)
_response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
_system: "..."
_tools: "...",
_images: [],
_videos: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=False,
remove_columns=column_names,
**kwargs,
)
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from transformers import ProcessorMixin
from .template import Template
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
r"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g.
```python
# input
[[1, 1, 2, 2, 2, 0]]
# output
[
[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, x, x, x, x],
]
]
]
```
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
bsz, seq_len = attention_mask_with_indices.size()
min_dtype = torch.finfo(dtype).min
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask = torch.where(expanded_mask != 0, 1, 0)
# Create a block-diagonal mask.
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
return attention_mask_4d
@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __post_init__(self):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
batch_images.extend(images)
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])
if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
else:
features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
batch_images = fake_images
batch_input_ids[0] = features[0]["input_ids"]
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
for i, feature in enumerate(features):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),
attention_mask=features["attention_mask"],
)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)
orig_len = cross_attention_mask.size(1)
mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
return features
@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
return features
@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features = []
for key in ("chosen", "rejected"):
for feature in features:
target_feature = {
"input_ids": feature[f"{key}_input_ids"],
"attention_mask": feature[f"{key}_attention_mask"],
"labels": feature[f"{key}_labels"],
"images": feature["images"],
"videos": feature["videos"],
}
concatenated_features.append(target_feature)
return super().__call__(concatenated_features)
@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
target_feature = {
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
"images": feature["images"],
"videos": feature["videos"],
}
kl_feature = {
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
"images": feature["images"],
"videos": feature["videos"],
}
target_features.append(target_feature)
kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
batch["kto_tags"] = torch.tensor(kto_tags)
return batch
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from ..extras import logging
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from ..hparams import DataArguments
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
OBSERVATION = "observation"
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets to a unified dataset.
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
"""
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
return DatasetDict({"train": train_set, "validation": val_set})
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing_extensions import override
from .data_utils import SLOTS
from .tool_utils import FunctionCall, get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
...
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extract a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
raise NotImplementedError
@dataclass
class EmptyFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
@override
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
@override
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
if isinstance(slot, str):
for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError(f"Expected a string, got {value}")
slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
return elements
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls:
functions.append(
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = []
for slot in self.slots:
if slot == "{{content}}":
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)
return elements
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
return self.tool_utils.tool_extractor(content)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments, ModelArguments
from .data_utils import DatasetModule
from .parser import DatasetAttr
from .template import Template
logger = logging.get_logger(__name__)
def _load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "script":
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "file":
data_files = []
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
else:
raise ValueError(f"File {local_path} not found.")
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
raise ValueError("File types should be identical.")
else:
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
dataset_name=data_path,
subset_name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=data_args.streaming,
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
else:
dataset = load_dataset(
path=data_path,
name=data_name,
data_dir=data_dir,
data_files=data_files,
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
indexes = np.concatenate((indexes, expand_indexes), axis=0)
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(max_samples))
return align_dataset(dataset, dataset_attr, data_args, training_args)
def _get_merged_dataset(
dataset_names: Optional[Sequence[str]],
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Gets the merged datasets in the standard format.
"""
if dataset_names is None:
return None
datasets = []
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
return merge_dataset(datasets, data_args, seed=training_args.seed)
def _get_preprocessed_dataset(
dataset: Optional[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None:
return None
preprocess_func, print_function = get_preprocess_and_print_func(
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Running tokenizer on dataset",
)
dataset = dataset.map(
preprocess_func,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
**kwargs,
)
if training_args.should_log:
try:
print("eval example:" if is_eval else "training example:")
print_function(next(iter(dataset)))
except StopIteration:
if stage == "pt":
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
else:
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset
def get_dataset(
template: "Template",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if isinstance(tokenized_data, DatasetDict):
if "train" in tokenized_data:
dataset_module["train_dataset"] = tokenized_data["train"]
if "validation" in tokenized_data:
dataset_module["eval_dataset"] = tokenized_data["validation"]
else: # Dataset
dataset_module["train_dataset"] = tokenized_data
if data_args.streaming:
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
return dataset_module
if data_args.streaming:
raise ValueError("Turn off `streaming` when saving dataset to disk.")
# Load and preprocess dataset
with training_args.main_process_first(desc="load dataset"):
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
with training_args.main_process_first(desc="pre-process dataset"):
dataset = _get_preprocessed_dataset(
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
eval_dataset = _get_preprocessed_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
if data_args.val_size > 1e-6:
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
else:
dataset_dict = {}
if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
dataset_dict["train"] = dataset
if eval_dataset is not None:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
dataset_dict["validation"] = eval_dataset
dataset_dict = DatasetDict(dataset_dict)
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)
dataset_module = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
return dataset_module
import math
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
if is_pillow_available():
from PIL import Image
from PIL.Image import Image as ImageObject
if is_pyav_available():
import av
if is_transformers_version_greater_than("4.45.0"):
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
if TYPE_CHECKING:
from av.stream import Stream
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
path: Optional[str]
bytes: Optional[bytes]
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
VideoInput = str
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
self.expand_mm_tokens = True
def _validate_input(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
) -> None:
r"""
Validates if this model accepts the input modalities.
"""
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
r"""
Pre-processes a single image.
"""
image_resolution: int = kwargs.get("image_resolution")
if (image.width * image.height) > image_resolution:
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
return image
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
r"""
Computes video sample frames according to fps.
"""
video_fps: float = kwargs.get("video_fps")
video_maxlen: int = kwargs.get("video_maxlen")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
sample_frames = min(total_frames, video_maxlen, sample_frames)
return math.floor(sample_frames)
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including reading and pre-processing.
"""
results = []
for image in images:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, bytes):
image = Image.open(BytesIO(image))
elif isinstance(image, dict):
if image["bytes"] is not None:
image = Image.open(BytesIO(image["bytes"]))
else:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
results.append(self._preprocess_image(image, **kwargs))
return results
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
r"""
Regularizes videos to avoid error. Including reading, resizing and converting.
"""
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
return results
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs.
Returns: (llava and paligemma)
pixel_values: tensor with shape (B, C, H, W)
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
input_dict = {"images": None} # default key
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
input_dict["images"] = images
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
input_dict["videos"] = videos
mm_inputs = {}
if image_processor != video_processor:
if input_dict.get("images") is not None:
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
if input_dict.get("videos") is not None:
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
return mm_inputs
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(images, videos)
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(images, videos)
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
return {}
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels
return input_ids, labels
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
class PixtralPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = mm_inputs.get("image_sizes", None)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")
if self.expand_mm_tokens:
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
else:
replace_str = image_token
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
mm_inputs.pop("image_sizes", None)
return mm_inputs
class Qwen2vlPlugin(BasePlugin):
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
image = super()._preprocess_image(image, **kwargs)
if min(image.width, image.height) < 28:
width, height = max(image.width, 28), max(image.height, 28)
image = image.resize((width, height), resample=Image.NEAREST)
if image.width / image.height > 200:
width, height = image.height * 180, image.height
image = image.resize((width, height), resample=Image.NEAREST)
if image.height / image.width > 200:
width, height = image.width, image.width * 180
image = image.resize((width, height), resample=Image.NEAREST)
return image
@override
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
frames.append(frames[-1])
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
return results
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
)
num_video_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_frames = 0
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
return image_processor([[image] for image in images], return_tensors="pt")
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
if len(images) != len(batch_ids):
raise ValueError("Mllama only supports one image per sample.")
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = torch.from_numpy(
convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
) # shape: (batch_size, length, max_num_images, max_num_tiles)
return mm_inputs
PLUGINS = {
"base": BasePlugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin,
"video_llava": VideoLlavaPlugin,
"mllama": MllamaPlugin,
}
def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind
@dataclass
class DatasetAttr:
r"""
Dataset attributes.
"""
# basic configs
load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
# extra configs
subset: Optional[str] = None
split: str = "train"
folder: Optional[str] = None
num_samples: Optional[int] = None
# common columns
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
videos: Optional[str] = None
# rlhf columns
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
# alpaca columns
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
# sharegpt columns
messages: Optional[str] = "conversations"
# sharegpt tags
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system"
def __repr__(self) -> str:
return self.dataset_name
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
Gets the attributes of the datasets.
"""
if dataset_names is None:
dataset_names = []
if dataset_dir == "ONLINE":
dataset_info = None
else:
if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else:
config_path = os.path.join(dataset_dir, DATA_CONFIG)
try:
with open(config_path) as f:
dataset_info = json.load(f)
except Exception as err:
if len(dataset_names) != 0:
raise ValueError(f"Cannot open {config_path} due to {str(err)}.")
dataset_info = None
dataset_list: List["DatasetAttr"] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
if name not in dataset_info:
raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")
has_hf_url = "hf_hub_url" in dataset_info[name]
has_ms_url = "ms_hub_url" in dataset_info[name]
has_om_url = "om_hub_url" in dataset_info[name]
if has_hf_url or has_ms_url or has_om_url:
if has_ms_url and (use_modelscope() or not has_hf_url):
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
elif has_om_url and (use_openmind() or not has_hf_url):
dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
else:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("split", dataset_info[name], default="train")
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
tag_names = (
"role_tag",
"content_tag",
"user_tag",
"assistant_tag",
"observation_tag",
"function_tag",
"system_tag",
)
for tag in tag_names:
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_list.append(dataset_attr)
return dataset_list
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
data_args: "DataArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
return TypedSequence.__init__(
self,
data,
type=kwargs.pop("type", None),
try_type=kwargs.pop("try_type", None),
optimized_int_type=kwargs.pop("optimized_int_type", None),
)
OptimizedTypedSequence.__init__ = __init__
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
elif stage == "kto":
preprocess_func = partial(
preprocess_feedback_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_feedback_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_pairwise_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
# consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
from typing import List, Sequence, Tuple
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
r"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index = bisect.bisect(numbers, capacity)
return -1 if index == 0 else (index - 1)
def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
r"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers.sort() # sort numbers in ascending order for binary search
knapsacks = []
while numbers:
current_knapsack = []
remaining_capacity = capacity
while True:
index = search_for_fit(numbers, remaining_capacity)
if index == -1:
break # no more numbers fit in this knapsack
remaining_capacity -= numbers[index] # update the remaining capacity
current_knapsack.append(numbers.pop(index)) # add the number to knapsack
knapsacks.append(current_knapsack)
return knapsacks
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
max_target_len = cutoff_len - source_len
else: # truncate both
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
max_source_len = max(cutoff_len - new_target_len, 0)
new_source_len = min(max_source_len, source_len)
return new_source_len, new_target_len
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_supervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
break
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if train_on_prompt:
source_label = source_ids
elif template.efficient_eos:
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_unsupervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras import logging
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from ..hparams import DataArguments
from .formatter import SLOTS, Formatter
from .mm_plugin import BasePlugin
from .tool_utils import FunctionCall
logger = logging.get_logger(__name__)
@dataclass
class Template:
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
answer_ids = encoded_messages[-1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
return self.format_tools.extract(content)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0:
elements += self.format_prefix.apply()
if system or tools:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return encoded_messages
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
token_ids = []
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids
@dataclass
class Llama2Template(Template):
@override
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: Sequence[Dict[str, str]],
system: str,
tools: str,
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0:
elements += self.format_prefix.apply()
if system or tools:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return encoded_messages
TEMPLATES: Dict[str, "Template"] = {}
def _register_template(
name: str,
format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None,
format_system: Optional["Formatter"] = None,
format_function: Optional["Formatter"] = None,
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Sequence[str] = [],
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None:
r"""
Registers a chat template.
To add the following chat template:
```
[HUMAN]:
user prompt here
[AI]:
model response here
[HUMAN]:
user prompt here
[AI]:
model response here
```
The corresponding code should be:
```
_register_template(
name="custom",
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
```
"""
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=default_slots)
default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
format_system=format_system or default_user_formatter,
format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
mm_plugin=mm_plugin,
)
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else:
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
return content.replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return " + ".join(slot_items)
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
if prefix:
jinja_template += "{{ " + prefix + " }}"
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in loop_messages %}"
jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
)
jinja_template += "{{ " + assistant_message + " }}"
jinja_template += "{% endif %}"
jinja_template += "{% endfor %}"
return jinja_template
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
else:
template = TEMPLATES.get(data_args.template, None)
if template is None:
raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
if tokenizer.chat_template is None or template.replace_jinja_template:
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
return template
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
replace_jinja_template=True,
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_separator=EmptyFormatter(slots=["###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
efficient_eos=True,
)
_register_template(
name="atom",
format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
),
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)
_register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
_register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_separator=EmptyFormatter(slots=["\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
)
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
)
_register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
)
# copied from chatml template
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
)
_register_template(
name="codegeex2",
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
_register_template(
name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
default_system=(
"你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
_register_template(
name="cohere",
format_user=StringFormatter(
slots=[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
# copied from chatml template
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)\nThis is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the DeepSeek Coder model, "
"developed by DeepSeek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer.\n"
),
)
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="empty",
efficient_eos=True,
)
_register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_separator=EmptyFormatter(slots=["\n"]),
efficient_eos=True,
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
)
_register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
)
_register_template(
name="granite3",
format_user=StringFormatter(
slots=[
"<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
]
),
format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
efficient_eos=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_separator=EmptyFormatter(slots=["<eoa>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<eoa>"],
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
)
# copied from llama2 template
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
)
# copied from llama3 template
_register_template(
name="mllama",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
# copied from vicuna template
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
# copied from vicuna template
_register_template(
name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from llama3 template
_register_template(
name="llava_next_llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from mistral template
_register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from chatml template
_register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from chatml template
_register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from vicuna template
_register_template(
name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
# copied from mistral template
_register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
# copied from chatml template
_register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
# copied from chatml template
_register_template(
name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="openchat-3.6",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
)
# copied from chatml template
_register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
# copied from gemma template
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
# copied from chatml template
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
)
# copied from chatml template
_register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
_register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant named Sailor created by Sea AI Lab. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|im_end|>"],
)
# copied from llama3 template
_register_template(
name="skywork_o1",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
"involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
"you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
"After completing your thoughts, you then provide a detailed explanation of the solution process "
"in your response."
),
stop_words=["<|eot_id|>"],
)
_register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
efficient_eos=True,
)
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)
_register_template(
name="telechat2",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}"]),
default_system=(
"你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。"
),
)
_register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
replace_jinja_template=True,
)
_register_template(
name="video_llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
)
_register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=(
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
),
)
_register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
stop_words=["<|End|>"],
)
# copied from chatml template
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. Read all the images carefully, "
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
),
stop_words=["###"],
efficient_eos=True,
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<eod>"],
replace_eos=True,
)
_register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
default_system="You are Zephyr, a helpful assistant.",
)
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
class FunctionCall(NamedTuple):
name: str
arguments: str
DEFAULT_TOOL_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}])\n"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n"""
"```\n"
)
GLM4_TOOL_PROMPT = (
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
LLAMA3_TOOL_PROMPT = (
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
"Do not use variables.\n\n{tool_text}"
)
QWEN_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
""""arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
)
@dataclass
class ToolUtils(ABC):
"""
Base class for tool utilities.
"""
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
r"""
Generates the assistant message including all the tool calls.
"""
...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
It should be an inverse function of `function_formatter`.
"""
...
class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required, enum, items = "", "", ""
if name in tool["parameters"].get("required", []):
required = ", required"
if param.get("enum", None):
enum = ", should be one of [{}]".format(", ".join(param["enum"]))
if param.get("items", None):
items = ", where each item should be {}".format(param["items"].get("type", ""))
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
items=items,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return [function_text]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
results = []
for match in action_match:
tool_name = match[0].strip()
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
except json.JSONDecodeError:
return content
return results
class GLM4ToolUtils(ToolUtils):
r"""
GLM-4 tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
)
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
return [f"{functions[0].name}\n{functions[0].arguments}"]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "\n" not in content:
return content
tool_name, tool_input = content.split("\n", maxsplit=1)
try:
arguments = json.loads(tool_input.strip())
except json.JSONDecodeError:
return content
return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
try:
tool = json.loads(content.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "parameters" not in tool:
return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return ["[" + ", ".join(function_texts) + "]"]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
try:
tools = json.loads(content.strip())
except json.JSONDecodeError:
return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return ["\n".join(function_texts)]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
if not tool_match:
return content
results = []
for tool in tool_match:
try:
tool = json.loads(tool.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
}
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError(f"Tool utils `{name}` not found.")
return tool_utils
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment