Commit 24534501 authored by mashun1

parallel_tool

parent c4ba4563
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        if model_args.rope_scaling == RopeScaling.DYNAMIC:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0
     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
......
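For reference, a minimal standalone sketch of the scaling-factor logic introduced above (plain Python; the config is mocked with SimpleNamespace and the "linear" rope type is only an illustrative value, not taken from this commit):

import math
from types import SimpleNamespace
from typing import Optional

def compute_rope_factor(model_max_length: Optional[int], old_max_length: int) -> float:
    # Training: enlarge the factor so the scaled context reaches model_max_length.
    if model_max_length is not None:
        return float(math.ceil(model_max_length / old_max_length))
    # Inference: the diff falls back to a fixed 2x extension.
    return 2.0

config = SimpleNamespace(max_position_embeddings=4096)
factor = compute_rope_factor(16384, config.max_position_embeddings)
config.rope_scaling = {"rope_type": "linear", "factor": factor}
config.max_position_embeddings = int(config.max_position_embeddings * factor)
print(config.rope_scaling)             # {'rope_type': 'linear', 'factor': 4.0}
print(config.max_position_embeddings)  # 16384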
@@ -24,6 +24,7 @@ import transformers.models
 from transformers.activations import ACT2FN
 from ...extras import logging
+from ...extras.packages import is_transformers_version_greater_than
 if TYPE_CHECKING:
@@ -281,7 +282,7 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -290,6 +291,6 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
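The version gate above relies on is_transformers_version_greater_than from the project's extras.packages. A rough stand-in (an assumption about its behaviour, not the project's exact implementation; it presumes transformers is installed) could look like this:

from importlib.metadata import version
from packaging.version import Version

def is_transformers_version_greater_than(target: str) -> bool:
    # Compare the installed transformers release against a threshold string.
    return Version(version("transformers")) >= Version(target)

# Pick the language-model module names depending on the installed release.
language_model_keys = ["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"]
print(language_model_keys)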
@@ -85,8 +85,8 @@ def patch_processor(
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
     setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
+    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
 def patch_config(
@@ -102,8 +102,8 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-    configure_attn_implementation(config, model_args, is_trainable)
-    configure_rope(config, model_args, is_trainable)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
......
@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
     @override
     def get_batch_samples(self, *args, **kwargs):
......
@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge
 if TYPE_CHECKING:
-    import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin
     from ...hparams import FinetuningArguments
@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
-        return Trainer._get_train_sampler(self)
+        return Trainer._get_train_sampler(self, *args, **kwargs)
     @override
     def get_batch_samples(self, *args, **kwargs):
......
@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
......
@@ -77,12 +77,23 @@ def run_pt(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
-        try:
-            perplexity = math.exp(metrics["eval_loss"])
-        except OverflowError:
-            perplexity = float("inf")
-        metrics["perplexity"] = perplexity
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            for key in dataset_module["eval_dataset"].keys():
+                try:
+                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
+                except OverflowError:
+                    perplexity = float("inf")
+                metrics[f"eval_{key}_perplexity"] = perplexity
+        else:
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+            metrics["eval_perplexity"] = perplexity
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
......
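The per-dataset branch above keys metrics as eval_<name>_loss, one entry per eval dataset. A small self-contained check of that pattern (the metrics dict and dataset names are hypothetical, not taken from this commit):

import math

metrics = {"eval_wiki_loss": 2.1, "eval_code_loss": 800.0}  # dummy values
eval_datasets = {"wiki": None, "code": None}  # placeholder datasets

for key in eval_datasets:
    try:
        perplexity = math.exp(metrics[f"eval_{key}_loss"])
    except OverflowError:  # math.exp overflows for very large losses
        perplexity = float("inf")
    metrics[f"eval_{key}_perplexity"] = perplexity

print(metrics["eval_wiki_perplexity"])  # ~8.17
print(metrics["eval_code_perplexity"])  # inf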
@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
     @override
     def compute_loss(
......
@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)
     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)
     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
......
@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional
 from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
     )
+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
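The update_attr helper added above is a plain setattr-based context manager. A quick usage sketch, with a throwaway object standing in for the chat template:

from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any

@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    # Temporarily override an attribute; restore the previous value on normal exit
    # (no try/finally, mirroring the diff).
    old_value = getattr(obj, name, None)
    setattr(obj, name, value)
    yield
    setattr(obj, name, old_value)

template = SimpleNamespace(enable_thinking=True)
with update_attr(template, "enable_thinking", False):
    print(template.enable_thinking)  # False while streaming this turn
print(template.enable_thinking)      # True again afterwards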
@@ -191,40 +200,42 @@ class WebChatModel(ChatModel):
         temperature: float,
         skip_special_tokens: bool,
         escape_html: bool,
+        enable_thinking: bool,
     ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
         r"""Generate output text in stream.
         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
-        chatbot.append({"role": "assistant", "content": ""})
-        response = ""
-        for new_text in self.stream_chat(
-            messages,
-            system,
-            tools,
-            images=[image] if image else None,
-            videos=[video] if video else None,
-            audios=[audio] if audio else None,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
-            skip_special_tokens=skip_special_tokens,
-        ):
-            response += new_text
-            if tools:
-                result = self.engine.template.extract_tool(response)
-            else:
-                result = response
-            if isinstance(result, list):
-                tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
-                tool_calls = json.dumps(tool_calls, ensure_ascii=False)
-                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
-                bot_text = "```json\n" + tool_calls + "\n```"
-            else:
-                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
-                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
-            chatbot[-1] = {"role": "assistant", "content": bot_text}
-            yield chatbot, output_messages
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
+            chatbot.append({"role": "assistant", "content": ""})
+            response = ""
+            for new_text in self.stream_chat(
+                messages,
+                system,
+                tools,
+                images=[image] if image else None,
+                videos=[video] if video else None,
+                audios=[audio] if audio else None,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                temperature=temperature,
+                skip_special_tokens=skip_special_tokens,
+            ):
+                response += new_text
+                if tools:
+                    result = self.engine.template.extract_tool(response)
+                else:
+                    result = response
+                if isinstance(result, list):
+                    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
+                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
+                    output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+                    bot_text = "```json\n" + tool_calls + "\n```"
+                else:
+                    output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
+                chatbot[-1] = {"role": "assistant", "content": bot_text}
+                yield chatbot, output_messages
@@ -205,6 +205,14 @@ def load_eval_results(path: os.PathLike) -> str:
     return f"```json\n{result}\n```\n"
+def calculate_pixels(pixels: str) -> int:
+    r"""Calculate the number of pixels from the expression."""
+    if "*" in pixels:
+        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
+    else:
+        return int(pixels)
 def create_ds_config() -> None:
     r"""Create deepspeed config in the current directory."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
......
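A quick sanity check of calculate_pixels against the default expressions used by the new multimodal text fields ("768*768" and friends); the restructured helper below is an equivalent sketch, not the committed code verbatim:

def calculate_pixels(pixels: str) -> int:
    # Accept either a "W*H" expression or a plain integer string.
    if "*" in pixels:
        width, height = pixels.split("*")
        return int(width) * int(height)
    return int(pixels)

print(calculate_pixels("768*768"))  # 589824
print(calculate_pixels("1024"))     # 1024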
@@ -79,6 +79,7 @@ def create_chat_box(
                 temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
                 skip_special_tokens = gr.Checkbox(value=True)
                 escape_html = gr.Checkbox(value=True)
+                enable_thinking = gr.Checkbox(value=True)
                 clear_btn = gr.Button()
     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
@@ -103,6 +104,7 @@ def create_chat_box(
             temperature,
             skip_special_tokens,
             escape_html,
+            enable_thinking,
         ],
         [chatbot, messages],
     )
@@ -127,6 +129,7 @@ def create_chat_box(
             temperature=temperature,
             skip_special_tokens=skip_special_tokens,
             escape_html=escape_html,
+            enable_thinking=enable_thinking,
             clear_btn=clear_btn,
         ),
     )
@@ -18,7 +18,7 @@ from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
 if is_gradio_available():
@@ -49,7 +49,7 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
-    )
+    ).then(check_template, [lang, template])
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
     model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
......
@@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 use_llama_pro = gr.Checkbox()
             with gr.Column():
+                enable_thinking = gr.Checkbox(value=True)
                 report_to = gr.Dropdown(
-                    choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
-                    value=["none"],
+                    choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
+                    value="none",
                     allow_custom_value=True,
-                    multiselect=True,
                 )
     input_elems.update(
@@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
+            enable_thinking,
             report_to,
         }
     )
@@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             mask_history=mask_history,
            resize_vocab=resize_vocab,
            use_llama_pro=use_llama_pro,
+            enable_thinking=enable_thinking,
            report_to=report_to,
        )
    )
@@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
        )
    )
+    with gr.Accordion(open=False) as mm_tab:
+        with gr.Row():
+            freeze_vision_tower = gr.Checkbox(value=True)
+            freeze_multi_modal_projector = gr.Checkbox(value=True)
+            freeze_language_model = gr.Checkbox(value=False)
+        with gr.Row():
+            image_max_pixels = gr.Textbox(value="768*768")
+            image_min_pixels = gr.Textbox(value="32*32")
+            video_max_pixels = gr.Textbox(value="256*256")
+            video_min_pixels = gr.Textbox(value="16*16")
+    input_elems.update(
+        {
+            freeze_vision_tower,
+            freeze_multi_modal_projector,
+            freeze_language_model,
+            image_max_pixels,
+            image_min_pixels,
+            video_max_pixels,
+            video_min_pixels,
+        }
+    )
+    elem_dict.update(
+        dict(
+            mm_tab=mm_tab,
+            freeze_vision_tower=freeze_vision_tower,
+            freeze_multi_modal_projector=freeze_multi_modal_projector,
+            freeze_language_model=freeze_language_model,
+            image_max_pixels=image_max_pixels,
+            image_min_pixels=image_min_pixels,
+            video_max_pixels=video_max_pixels,
+            video_min_pixels=video_min_pixels,
+        )
+    )
     with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
             use_galore = gr.Checkbox()
......
@@ -84,6 +84,17 @@ def get_model_info(model_name: str) -> tuple[str, str]:
     return get_model_path(model_name), get_template(model_name)
+def check_template(lang: str, template: str) -> None:
+    r"""Check if an instruct model is used.
+    Please use queue=True to show the warning message.
+    Inputs: top.lang, top.template
+    """
+    if template == "default":
+        gr.Warning(ALERTS["warn_no_instruct"][lang])
 def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
     r"""Get training infomation for monitor.
......
@@ -871,6 +871,28 @@ LOCALES = {
            "info": "拡張ブロックのパラメータのみをトレーニングします。",
        },
    },
"enable_thinking": {
"en": {
"label": "Enable thinking",
"info": "Whether or not to enable thinking mode for reasoning models.",
},
"ru": {
"label": "Включить мысли",
"info": "Включить режим мысли для моделей решающего характера.",
},
"zh": {
"label": "启用思考模式",
"info": "是否启用推理模型的思考模式。",
},
"ko": {
"label": "생각 모드 활성화",
"info": "추론 모델의 생각 모드를 활성화할지 여부.",
},
"ja": {
"label": "思考モードを有効化",
"info": "推論モデルの思考モードを有効にするかどうか。",
},
},
"report_to": { "report_to": {
"en": { "en": {
"label": "Enable external logger", "label": "Enable external logger",
...@@ -1374,6 +1396,177 @@ LOCALES = { ...@@ -1374,6 +1396,177 @@ LOCALES = {
"info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。", "info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。",
}, },
}, },
"mm_tab": {
"en": {
"label": "Multimodal configurations",
},
"ru": {
"label": "Конфигурации мультимедиа",
},
"zh": {
"label": "多模态参数设置",
},
"ko": {
"label": "멀티모달 구성",
},
"ja": {
"label": "多モーダル設定",
},
},
"freeze_vision_tower": {
"en": {
"label": "Freeze vision tower",
"info": "Freeze the vision tower in the model.",
},
"ru": {
"label": "Заморозить башню визиона",
"info": "Заморозить башню визиона в модели.",
},
"zh": {
"label": "冻结视觉编码器",
"info": "冻结模型中的视觉编码器。",
},
"ko": {
"label": "비전 타워 고정",
"info": "모델의 비전 타워를 고정합니다.",
},
"ja": {
"label": "ビジョンタワーの固定",
"info": "モデルのビジョンタワーを固定します。",
},
},
"freeze_multi_modal_projector": {
"en": {
"label": "Freeze multi-modal projector",
"info": "Freeze the multi-modal projector in the model.",
},
"ru": {
"label": "Заморозить мультимодальный проектор",
"info": "Заморозить мультимодальный проектор в модели.",
},
"zh": {
"label": "冻结多模态投影器",
"info": "冻结模型中的多模态投影器。",
},
"ko": {
"label": "멀티모달 프로젝터 고정",
"info": "모델의 멀티모달 프로젝터를 고정합니다.",
},
"ja": {
"label": "多モーダルプロジェクターの固定",
"info": "モデルの多モーダルプロジェクターを固定します。",
},
},
"freeze_language_model": {
"en": {
"label": "Freeze language model",
"info": "Freeze the language model in the model.",
},
"ru": {
"label": "Заморозить язык модели",
"info": "Заморозить язык модели в модели.",
},
"zh": {
"label": "冻结语言模型",
"info": "冻结模型中的语言模型。",
},
"ko": {
"label": "언어 모델 고정",
"info": "모델의 언어 모델을 고정합니다.",
},
"ja": {
"label": "言語モデルの固定",
"info": "モデルの言語モデルを固定します。",
},
},
"image_max_pixels": {
"en": {
"label": "Image max pixels",
"info": "The maximum number of pixels of image inputs.",
},
"ru": {
"label": "Максимальное количество пикселей изображения",
"info": "Максимальное количество пикселей изображения.",
},
"zh": {
"label": "图像最大像素",
"info": "输入图像的最大像素数。",
},
"ko": {
"label": "이미지 최대 픽셀",
"info": "이미지 입력의 최대 픽셀 수입니다.",
},
"ja": {
"label": "画像最大ピクセル",
"info": "画像入力の最大ピクセル数です。",
},
},
"image_min_pixels": {
"en": {
"label": "Image min pixels",
"info": "The minimum number of pixels of image inputs.",
},
"ru": {
"label": "Минимальное количество пикселей изображения",
"info": "Минимальное количество пикселей изображения.",
},
"zh": {
"label": "图像最小像素",
"info": "输入图像的最小像素数。",
},
"ko": {
"label": "이미지 최소 픽셀",
"info": "이미지 입력의 최소 픽셀 수입니다.",
},
"ja": {
"label": "画像最小ピクセル",
"info": "画像入力の最小ピクセル数です。",
},
},
"video_max_pixels": {
"en": {
"label": "Video max pixels",
"info": "The maximum number of pixels of video inputs.",
},
"ru": {
"label": "Максимальное количество пикселей видео",
"info": "Максимальное количество пикселей видео.",
},
"zh": {
"label": "视频最大像素",
"info": "输入视频的最大像素数。",
},
"ko": {
"label": "비디오 최대 픽셀",
"info": "비디오 입력의 최대 픽셀 수입니다.",
},
"ja": {
"label": "ビデオ最大ピクセル",
"info": "ビデオ入力の最大ピクセル数です。",
},
},
"video_min_pixels": {
"en": {
"label": "Video min pixels",
"info": "The minimum number of pixels of video inputs.",
},
"ru": {
"label": "Минимальное количество пикселей видео",
"info": "Минимальное количество пикселей видео.",
},
"zh": {
"label": "视频最小像素",
"info": "输入视频的最小像素数。",
},
"ko": {
"label": "비디오 최소 픽셀",
"info": "비디오 입력의 최소 픽셀 수입니다.",
},
"ja": {
"label": "ビデオ最小ピクセル",
"info": "ビデオ入力の最小ピクセル数です。",
},
},
"galore_tab": { "galore_tab": {
"en": { "en": {
"label": "GaLore configurations", "label": "GaLore configurations",
...@@ -2779,6 +2972,13 @@ ALERTS = { ...@@ -2779,6 +2972,13 @@ ALERTS = {
"ko": "출력 디렉토리가 이미 존재합니다. 위 출력 디렉토리에 저장된 학습을 재개합니다.", "ko": "출력 디렉토리가 이미 존재합니다. 위 출력 디렉토리에 저장된 학습을 재개합니다.",
"ja": "出力ディレクトリが既に存在します。このチェックポイントからトレーニングを再開します。", "ja": "出力ディレクトリが既に存在します。このチェックポイントからトレーニングを再開します。",
}, },
"warn_no_instruct": {
"en": "You are using a non-instruct model, please fine-tune it first.",
"ru": "Вы используете модель без инструкции, пожалуйста, primeros выполните донастройку этой модели.",
"zh": "您正在使用非指令模型,请先对其进行微调。",
"ko": "당신은 지시하지 않은 모델을 사용하고 있습니다. 먼저 이를 미세 조정해 주세요.",
"ja": "インストラクションモデルを使用していません。まずモデルをアダプターに適合させてください。",
},
"info_aborting": { "info_aborting": {
"en": "Aborted, wait for terminating...", "en": "Aborted, wait for terminating...",
"ru": "Прервано, ожидание завершения...", "ru": "Прервано, ожидание завершения...",
......
@@ -29,6 +29,7 @@ from .common import (
     DEFAULT_CACHE_DIR,
     DEFAULT_CONFIG_DIR,
     abort_process,
+    calculate_pixels,
     gen_cmd,
     get_save_dir,
     load_args,
@@ -162,7 +163,15 @@ class Runner:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
+            enable_thinking=get("train.enable_thinking"),
             report_to=get("train.report_to"),
+            freeze_vision_tower=get("train.freeze_vision_tower"),
+            freeze_multi_modal_projector=get("train.freeze_multi_modal_projector"),
+            freeze_language_model=get("train.freeze_language_model"),
+            image_max_pixels=calculate_pixels(get("train.image_max_pixels")),
+            image_min_pixels=calculate_pixels(get("train.image_min_pixels")),
+            video_max_pixels=calculate_pixels(get("train.video_max_pixels")),
+            video_min_pixels=calculate_pixels(get("train.video_min_pixels")),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
@@ -256,12 +265,6 @@ class Runner:
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")
-        # report_to
-        if "none" in args["report_to"]:
-            args["report_to"] = "none"
-        elif "all" in args["report_to"]:
-            args["report_to"] = "all"
         # swanlab config
         if get("train.use_swanlab"):
             args["swanlab_project"] = get("train.swanlab_project")
......
@@ -50,7 +50,7 @@ def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
@@ -60,7 +60,7 @@ def test_multi_function_formatter():
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
@@ -85,7 +85,7 @@ def test_default_tool_formatter():
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
-    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
         """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
     )
     assert formatter.extract(result) == [
         ("test_tool", """{"foo": "bar", "size": 10}"""),
@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
     ]
+def test_llama3_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
+        """<|eot_id|>"""
+    ]
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")
@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+def test_llama3_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="llama3")
+    result = (
+        """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)
......
@@ -135,8 +135,7 @@ def _check_plugin(
     expected_mm_inputs: dict[str, Any] = {},
     expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
-    # test omni_messages
-    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
+    if plugin.__class__.__name__ == "Qwen2OmniPlugin":  # test omni_messages
         assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
             expected_input_ids,
@@ -146,8 +145,7 @@ def _check_plugin(
             plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
             expected_mm_inputs,
         )
-    # test mm_messages
-    if plugin.__class__.__name__ != "BasePlugin":
+    elif plugin.__class__.__name__ != "BasePlugin":  # test mm_messages
         assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
             expected_input_ids,
@@ -201,7 +199,7 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)
-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
@@ -219,7 +217,7 @@ def test_internvl_plugin():
     _check_plugin(**check_inputs)
-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
 def test_llama4_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
     processor = tokenizer_module["processor"]
@@ -321,10 +319,9 @@ def test_pixtral_plugin():
     _check_plugin(**check_inputs)
-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_qwen2_omni_plugin():
-    image_seqlen = 4
-    audio_seqlen = 2
+    image_seqlen, audio_seqlen = 4, 2
     tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
     qwen2_omni_plugin = get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
......
@@ -125,6 +125,60 @@ def test_encode_multiturn(use_fast: bool):
     )
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, input_messages)
output_messages = MESSAGES if enable_thinking is False else input_messages
prompt_str = (
f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
f"{MESSAGES[1]['content']}<|im_end|>\n"
f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
answer_str = f"{output_messages[3]['content']}<|im_end|>\n"
if not cot_messages or enable_thinking is False:
if enable_thinking:
answer_str = "<think>\n\n</think>\n\n" + answer_str
else:
prompt_str = prompt_str + "<think>\n\n</think>\n\n"
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
encoded_pairs = template.encode_multiturn(tokenizer, input_messages)
output_messages = MESSAGES if enable_thinking is False else input_messages
prompt_str_1 = f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
answer_str_1 = f"{output_messages[1]['content']}<|im_end|>\n"
prompt_str_2 = f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
answer_str_2 = f"{output_messages[3]['content']}<|im_end|>\n"
if not cot_messages or enable_thinking is False:
if enable_thinking:
answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
else:
prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
_check_tokenization(
tokenizer,
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
)
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_jinja_template(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -162,12 +216,12 @@ def test_get_stop_token_ids():
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma_template(use_fast: bool):
     prompt_str = (
-        "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
-        "<start_of_turn>model\nI am fine!<end_of_turn>\n"
-        "<start_of_turn>user\n你好<end_of_turn>\n"
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!<end_of_turn>\n"
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@@ -175,12 +229,12 @@ def test_gemma_template(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
-        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
+        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
+        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot_id|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@@ -189,12 +243,12 @@ def test_llama3_template(use_fast: bool):
 )
 def test_llama4_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
-        "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
-        "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
+        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
+        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
        "<|header_start|>assistant<|header_end|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
     _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
@@ -203,12 +257,12 @@ def test_llama4_template(use_fast: bool):
 )
 def test_phi4_template(use_fast: bool):
     prompt_str = (
-        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
-        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
-        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
+        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
         "<|im_start|>assistant<|im_sep|>"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@@ -216,25 +270,30 @@ def test_phi4_template(use_fast: bool):
 def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen3_template(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+def test_qwen3_template(use_fast: bool, cot_messages: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+    answer_str = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        answer_str = "<think>\n\n</think>\n\n" + answer_str
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
 def test_parse_llama3_template():
@@ -253,6 +312,7 @@ def test_parse_llama3_template():
 def test_parse_qwen_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
@@ -263,6 +323,7 @@ def test_parse_qwen_template():
 def test_parse_qwen3_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
......