Commit 53b3977b authored by dongchy920's avatar dongchy920
Browse files

Initial commit

parents
Pipeline #2841 failed with stages
in 0 seconds
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import psutil
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS
if is_gradio_available():
import gradio as gr
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
try:
children = psutil.Process(pid).children()
if children:
for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT)
except Exception:
pass
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Judges if the quantization is available in this finetuning type.
"""
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Returns the available quantization bits.
"""
if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
available_bits = ["none", "8", "4"]
elif quantization_method == QuantizationMethod.HQQ.value:
available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
elif quantization_method == QuantizationMethod.EETQ.value:
available_bits = ["none", "8"]
return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifys states after changing the training stage.
"""
return [], TRAINING_STAGES[training_stage] == "pt"
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates arguments for previewing.
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in clean_cmd(args).items():
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves arguments to launch training.
"""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINING_ARGS)
def get_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return f"```json\n{result}\n```\n"
def get_time() -> str:
r"""
Gets current date and time.
"""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
r"""
Gets training infomation for monitor.
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
if len(trainer_log) != 0:
latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads saved arguments.
"""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]):
r"""
Saves arguments.
"""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
"""
config_files = [f"{current_time}.yaml"]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml") and file_name not in config_files:
config_files.append(file_name)
return gr.Dropdown(choices=config_files)
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
"""
output_dirs = [f"train_{current_time}"]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for folder in os.listdir(save_dir):
output_dir = os.path.join(save_dir, folder)
if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
output_dirs.append(folder)
return gr.Dropdown(choices=output_dirs)
def create_ds_config() -> None:
r"""
Creates deepspeed config.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
}
offload_config = {
"device": "cpu",
"pin_memory": True,
}
ds_config["zero_optimization"] = {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"round_robin_gradients": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"] = {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
ds_config["zero_optimization"]["offload_param"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp
def main():
run_exp()
def _mp_fn(index):
# For xla_spawn (TPUs)
run_exp()
if __name__ == "__main__":
main()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from llamafactory.webui.interface import create_ui
def main():
gradio_ipv6 = os.getenv("GRADIO_IPV6", "0").lower() in ["true", "1"]
gradio_share = os.getenv("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
if __name__ == "__main__":
main()
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "write a quick sort algorithm."
messages = [
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=512
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
\ No newline at end of file
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import pytest
from datasets import load_dataset
from transformers import AutoTokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.train.test_utils import load_train_dataset
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "kto",
"do_train": True,
"finetuning_type": "full",
"dataset": "kto_en_demo",
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 8192,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
"fp16": True,
}
@pytest.mark.parametrize("num_samples", [16])
def test_feedback_data(num_samples: int):
train_dataset = load_train_dataset(**TRAIN_ARGS)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="kto_en_demo", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
messages = original_data["messages"][index]
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
assert train_dataset["input_ids"][index] == ref_input_ids
assert train_dataset["labels"][index] == ref_labels
assert train_dataset["kto_tags"][index] == original_data["label"][index]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Dict, List
import pytest
from datasets import load_dataset
from transformers import AutoTokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.train.test_utils import load_train_dataset
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "rm",
"do_train": True,
"finetuning_type": "full",
"dataset": "dpo_en_demo",
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 8192,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
"fp16": True,
}
def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
new_messages = []
for message in messages:
new_messages.append({"role": role_mapping[message["from"]], "content": message["value"]})
return new_messages
@pytest.mark.parametrize("num_samples", [16])
def test_pairwise_data(num_samples: int):
train_dataset = load_train_dataset(**TRAIN_ARGS)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="dpo_en_demo", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
chosen_messages = original_data["conversations"][index] + [original_data["chosen"][index]]
rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
rejected_prompt_len = len(
ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
)
ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
assert train_dataset["chosen_labels"][index] == ref_chosen_labels
assert train_dataset["rejected_input_ids"][index] == ref_rejected_input_ids
assert train_dataset["rejected_labels"][index] == ref_rejected_labels
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import pytest
from llamafactory.data.processors.processor_utils import infer_seqlen
@pytest.mark.parametrize(
"test_input,test_output",
[
((3000, 2000, 1000), (600, 400)),
((2000, 3000, 1000), (400, 600)),
((1000, 100, 1000), (900, 100)),
((100, 1000, 1000), (100, 900)),
((100, 500, 1000), (100, 500)),
((500, 100, 1000), (500, 100)),
((10, 10, 1000), (10, 10)),
],
)
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
assert test_output == infer_seqlen(*test_input)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import pytest
from datasets import load_dataset
from transformers import AutoTokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.train.test_utils import load_train_dataset
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "full",
"template": "llama3",
"cutoff_len": 8192,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
"fp16": True,
}
@pytest.mark.parametrize("num_samples", [16])
def test_supervised_single_turn(num_samples: int):
train_dataset = load_train_dataset(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(TINY_DATA, split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
prompt = original_data["instruction"][index]
if original_data["input"][index]:
prompt += "\n" + original_data["input"][index]
messages = [
{"role": "user", "content": prompt},
{"role": "assistant", "content": original_data["output"][index]},
]
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
assert train_dataset["input_ids"][index] == ref_input_ids
@pytest.mark.parametrize("num_samples", [8])
def test_supervised_multi_turn(num_samples: int):
train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
assert train_dataset["input_ids"][index] == ref_input_ids
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_train_on_prompt(num_samples: int):
train_dataset = load_train_dataset(
dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", train_on_prompt=True, **TRAIN_ARGS
)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
assert train_dataset["input_ids"][index] == ref_ids
assert train_dataset["labels"][index] == ref_ids
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_mask_history(num_samples: int):
train_dataset = load_train_dataset(
dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", mask_history=True, **TRAIN_ARGS
)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
messages = original_data["messages"][index]
ref_input_ids = ref_tokenizer.apply_chat_template(messages)
prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
assert train_dataset["input_ids"][index] == ref_input_ids
assert train_dataset["labels"][index] == ref_label_ids
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import pytest
from datasets import load_dataset
from transformers import AutoTokenizer
from llamafactory.train.test_utils import load_train_dataset
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "ppo",
"do_train": True,
"finetuning_type": "full",
"reward_model": "",
"reward_model_type": "full",
"dataset": "system_chat",
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 8192,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
"fp16": True,
}
@pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int):
train_dataset = load_train_dataset(**TRAIN_ARGS)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
messages = original_data["messages"][index]
ref_ids = ref_tokenizer.apply_chat_template(messages)
ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
ref_labels = ref_ids[len(ref_input_ids) :]
assert train_dataset["input_ids"][index] == ref_input_ids
assert train_dataset["labels"][index] == ref_labels
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from PIL import Image
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
def test_base_collator():
model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
p = tokenizer_module["tokenizer"].pad_token_id
q = IGNORE_INDEX
features = [
{
"input_ids": [0, 1, 2, 3, 4, 5],
"attention_mask": [1, 1, 1, 1, 1, 1],
"labels": [q, q, 2, 3, 4, 5],
},
{
"input_ids": [6, 7],
"attention_mask": [1, 1],
"labels": [q, 7],
},
]
batch_input = data_collator(features)
expected_input = {
"input_ids": [
[0, 1, 2, 3, 4, 5, p, p],
[6, 7, p, p, p, p, p, p],
],
"attention_mask": [
[1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0],
],
"labels": [
[q, q, 2, 3, 4, 5, q, q],
[q, 7, q, q, q, q, q, q],
],
}
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
def test_multimodal_collator():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
p = tokenizer_module["tokenizer"].pad_token_id
q = IGNORE_INDEX
s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
features = [
{
"input_ids": [0, 1, 2, 3],
"attention_mask": [1, 1, 1, 1],
"labels": [0, 1, 2, 3],
},
]
batch_input = data_collator(features)
expected_input = {
"input_ids": [
[0, 1, 2, 3, s, m, m, m, m, e, p, p],
],
"attention_mask": [
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
],
"labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
],
**tokenizer_module["processor"].image_processor(fake_image),
}
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
def test_4d_attention_mask():
o = 0.0
x = torch.finfo(torch.float16).min
attention_mask_with_indices = torch.tensor(
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
)
attention_mask_computed = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16)
attention_mask_expected = torch.tensor(
[
[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, x, x, x, x],
]
],
[
[
[o, x, x, x, x, x],
[x, o, x, x, x, x],
[x, o, o, x, x, x],
[x, x, x, o, x, x],
[x, x, x, o, o, x],
[x, x, x, o, o, o],
]
],
],
dtype=torch.float16,
)
assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
assert torch.all(attention_mask_computed == attention_mask_expected)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from datetime import datetime
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
TOOLS = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
def test_string_formatter():
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
def test_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
assert formatter.apply(content=json.dumps(TOOLS)) == [
"You have access to the following tools:\n"
"> Tool Name: test_tool\n"
"Tool Description: tool_desc\n"
"Tool Args:\n"
" - foo (string, required): foo_desc\n"
" - bar (number): bar_desc\n\n"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [test_tool])\n"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{"input": "hello world", "num_beams": 5}```)\n"""
"```\n"
]
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_glm4_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
assert formatter.apply(content=json.dumps(TOOLS)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
"<|eot_id|>",
]
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
]
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
]
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
""""arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
]
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (
"""<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
import pytest
import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from llamafactory.data.mm_plugin import BasePlugin
from llamafactory.model.loader import TokenizerModule
HF_TOKEN = os.getenv("HF_TOKEN")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MM_MESSAGES = [
{"role": "user", "content": "<image>What is in this image?"},
{"role": "assistant", "content": "A cat."},
]
TEXT_MESSAGES = [
{"role": "user", "content": "How are you"},
{"role": "assistant", "content": "I am fine!"},
]
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = []
NO_VIDEOS = []
IMGLENS = [1]
NO_IMGLENS = [0]
NO_VIDLENS = [0]
INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4]
BATCH_IDS = [[1] * 1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
return image_processor(images=IMAGES, return_tensors="pt")
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
assert len(batch_a[key]) == len(batch_b[key])
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
else:
assert batch_a[key] == batch_b[key]
def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
return load_tokenizer(model_args)
def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
expected_input_ids: List[int] = INPUT_IDS,
expected_labels: List[int] = LABELS,
expected_mm_inputs: Dict[str, Any] = {},
expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
# test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
# test text_messages
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
INPUT_IDS,
LABELS,
)
_is_close(
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
expected_no_mm_inputs,
)
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
check_inputs = {"plugin": base_plugin, **tokenizer_module}
_check_plugin(**check_inputs)
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
check_inputs = {"plugin": llava_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
def test_llava_next_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
def test_llava_next_video_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
]
check_inputs["expected_input_ids"] = [
tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
] * image_seqlen + INPUT_IDS
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
_check_plugin(**check_inputs)
def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace(
"<image>",
("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
+ "[IMG_END]",
)
for key, value in message.items()
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs)
def test_qwen2_vl_plugin():
image_seqlen = 4
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
for key, value in message.items()
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
def test_video_llava_plugin():
image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, List, Sequence
import pytest
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import _get_jinja_template
from llamafactory.hparams import DataArguments
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
HF_TOKEN = os.getenv("HF_TOKEN")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MESSAGES = [
{"role": "user", "content": "How are you"},
{"role": "assistant", "content": "I am fine!"},
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "很高兴认识你!"},
]
def _check_tokenization(
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
) -> None:
for input_ids, text in zip(batch_input_ids, batch_text):
assert input_ids == tokenizer.encode(text, add_special_tokens=False)
assert tokenizer.decode(input_ids) == text
def _check_single_template(
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
) -> List[str]:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str + extra_str
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
return content_ids
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
"""
Checks template for both the slow tokenizer and the fast tokenizer.
Args:
model_id: the model id on hugging face hub.
template_name: the template name.
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
extra_str: the extra string in the jinja template of the original tokenizer.
"""
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
assert slow_ids == fast_ids
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str = "很高兴认识你!<|eot_id|>"
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
prompt_str_1 = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str_1 = "I am fine!<|eot_id|>"
prompt_str_2 = (
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str_2 = "很高兴认识你!<|eot_id|>"
_check_tokenization(
tokenizer,
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
)
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
tokenizer.chat_template = _get_jinja_template(template, tokenizer) # llama3 template no replace
assert tokenizer.chat_template != ref_tokenizer.chat_template
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_gemma_template():
prompt_str = (
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
"<start_of_turn>user\n你好<end_of_turn>\n"
"<start_of_turn>model\n"
)
answer_str = "很高兴认识你!"
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_llama3_template():
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str = "很高兴认识你!<|eot_id|>"
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
def test_qwen_template():
prompt_str = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>"
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
def test_yi_template():
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>"
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from llamafactory.chat import ChatModel
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"do_sample": False,
"max_new_tokens": 1,
}
MESSAGES = [
{"role": "user", "content": "Hi"},
]
EXPECTED_RESPONSE = "_rho"
def test_chat():
chat_model = ChatModel(INFER_ARGS)
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
def test_stream_chat():
chat_model = ChatModel(INFER_ARGS)
response = ""
for token in chat_model.stream_chat(MESSAGES):
response += token
assert response == EXPECTED_RESPONSE
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.train.tuner import export_model, run_exp
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"do_train": True,
"finetuning_type": "lora",
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 1,
"overwrite_cache": False,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
}
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
}
OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.parametrize(
"stage,dataset",
[
("pt", "c4_demo"),
("sft", "alpaca_en_demo"),
("dpo", "dpo_en_demo"),
("kto", "kto_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = os.path.join("output", f"train_{stage}")
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)
def test_export():
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
assert os.path.exists(export_dir)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.eval.template import get_eval_template
def test_eval_template_en():
support_set = [
{
"question": "Fewshot question",
"A": "Fewshot1",
"B": "Fewshot2",
"C": "Fewshot3",
"D": "Fewshot4",
"answer": "B",
}
]
example = {
"question": "Target question",
"A": "Target1",
"B": "Target2",
"C": "Target3",
"D": "Target4",
"answer": "C",
}
template = get_eval_template(name="en")
messages = template.format_example(example, support_set=support_set, subject_name="SubName")
assert messages == [
{
"role": "user",
"content": (
"The following are multiple choice questions (with answers) about SubName.\n\n"
"Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
),
},
{"role": "assistant", "content": "B"},
{
"role": "user",
"content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
},
{"role": "assistant", "content": "C"},
]
def test_eval_template_zh():
support_set = [
{
"question": "示例问题",
"A": "示例答案1",
"B": "示例答案2",
"C": "示例答案3",
"D": "示例答案4",
"answer": "B",
}
]
example = {
"question": "目标问题",
"A": "目标答案1",
"B": "目标答案2",
"C": "目标答案3",
"D": "目标答案4",
"answer": "C",
}
template = get_eval_template(name="zh")
messages = template.format_example(example, support_set=support_set, subject_name="主题")
assert messages == [
{
"role": "user",
"content": (
"以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
"示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
),
},
{"role": "assistant", "content": "B"},
{
"role": "user",
"content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
},
{"role": "assistant", "content": "C"},
]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from llamafactory.train.test_utils import load_infer_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"template": "llama3",
}
def test_attention():
attention_available = ["disabled"]
if is_torch_sdpa_available():
attention_available.append("sdpa")
if is_flash_attn_2_available():
attention_available.append("fa2")
llama_attention_classes = {
"disabled": "LlamaAttention",
"sdpa": "LlamaSdpaAttention",
"fa2": "LlamaFlashAttention2",
}
for requested_attention in attention_available:
model = load_infer_model(flash_attn=requested_attention, **INFER_ARGS)
for module in model.modules():
if "Attention" in module.__class__.__name__:
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from llamafactory.extras.misc import get_current_device
from llamafactory.train.test_utils import load_train_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"stage": "sft",
"do_train": True,
"finetuning_type": "lora",
"lora_target": "all",
"dataset": "llamafactory/tiny-supervised-dataset",
"dataset_dir": "ONLINE",
"template": "llama3",
"cutoff_len": 1024,
"overwrite_cache": True,
"output_dir": "dummy_dir",
"overwrite_output_dir": True,
"fp16": True,
}
def test_checkpointing_enable():
model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert getattr(module, "gradient_checkpointing") is True
def test_checkpointing_disable():
model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert getattr(module, "gradient_checkpointing") is False
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
def test_upcast_layernorm():
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
for name, param in model.named_parameters():
if param.ndim == 1 and "norm" in name:
assert param.dtype == torch.float32
def test_upcast_lmhead_output():
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
assert outputs.dtype == torch.float32
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
@pytest.mark.parametrize(
"attention_mask,golden_seq_lens",
[
(
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
],
[2, 3, 1, 2, 3],
),
(
[[1]],
[1],
),
],
)
def test_get_seqlens_in_batch(attention_mask, golden_seq_lens):
attention_mask_with_indices = torch.tensor(attention_mask)
seqlens_in_batch = get_seqlens_in_batch(attention_mask_with_indices)
assert torch.all(seqlens_in_batch == torch.tensor(golden_seq_lens))
@pytest.mark.parametrize(
"attention_mask,golden_indices,golden_cu_seqlens,golden_max_seqlen",
[
(
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
],
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11],
[0, 2, 5, 6, 8, 11],
3,
),
(
[[1]],
[0],
[0, 1],
1,
),
],
)
def test_get_unpad_data(attention_mask, golden_indices, golden_cu_seqlens, golden_max_seqlen):
attention_mask_with_indices = torch.tensor(attention_mask)
indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(attention_mask_with_indices)
assert torch.all(indices == torch.tensor(golden_indices))
assert torch.all(cu_seqlens == torch.tensor(golden_cu_seqlens, dtype=torch.int32))
assert max_seqlen_in_batch == golden_max_seqlen
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"template": "llama3",
"infer_dtype": "float16",
}
@pytest.fixture
def fix_valuehead_cpu_loading():
patch_valuehead_model()
def test_base():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA)
compare_model(model, ref_model)
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
model = load_infer_model(add_valuehead=True, **INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, add_valuehead=True)
compare_model(model, ref_model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment