Commit 0722acf1 authored by chenych's avatar chenych
Browse files

Update 0604

parent c4ba4563
......@@ -871,6 +871,28 @@ LOCALES = {
"info": "拡張ブロックのパラメータのみをトレーニングします。",
},
},
"enable_thinking": {
"en": {
"label": "Enable thinking",
"info": "Whether or not to enable thinking mode for reasoning models.",
},
"ru": {
"label": "Включить мысли",
"info": "Включить режим мысли для моделей решающего характера.",
},
"zh": {
"label": "启用思考模式",
"info": "是否启用推理模型的思考模式。",
},
"ko": {
"label": "생각 모드 활성화",
"info": "추론 모델의 생각 모드를 활성화할지 여부.",
},
"ja": {
"label": "思考モードを有効化",
"info": "推論モデルの思考モードを有効にするかどうか。",
},
},
"report_to": {
"en": {
"label": "Enable external logger",
......@@ -1374,6 +1396,177 @@ LOCALES = {
"info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。",
},
},
"mm_tab": {
"en": {
"label": "Multimodal configurations",
},
"ru": {
"label": "Конфигурации мультимедиа",
},
"zh": {
"label": "多模态参数设置",
},
"ko": {
"label": "멀티모달 구성",
},
"ja": {
"label": "多モーダル設定",
},
},
"freeze_vision_tower": {
"en": {
"label": "Freeze vision tower",
"info": "Freeze the vision tower in the model.",
},
"ru": {
"label": "Заморозить башню визиона",
"info": "Заморозить башню визиона в модели.",
},
"zh": {
"label": "冻结视觉编码器",
"info": "冻结模型中的视觉编码器。",
},
"ko": {
"label": "비전 타워 고정",
"info": "모델의 비전 타워를 고정합니다.",
},
"ja": {
"label": "ビジョンタワーの固定",
"info": "モデルのビジョンタワーを固定します。",
},
},
"freeze_multi_modal_projector": {
"en": {
"label": "Freeze multi-modal projector",
"info": "Freeze the multi-modal projector in the model.",
},
"ru": {
"label": "Заморозить мультимодальный проектор",
"info": "Заморозить мультимодальный проектор в модели.",
},
"zh": {
"label": "冻结多模态投影器",
"info": "冻结模型中的多模态投影器。",
},
"ko": {
"label": "멀티모달 프로젝터 고정",
"info": "모델의 멀티모달 프로젝터를 고정합니다.",
},
"ja": {
"label": "多モーダルプロジェクターの固定",
"info": "モデルの多モーダルプロジェクターを固定します。",
},
},
"freeze_language_model": {
"en": {
"label": "Freeze language model",
"info": "Freeze the language model in the model.",
},
"ru": {
"label": "Заморозить язык модели",
"info": "Заморозить язык модели в модели.",
},
"zh": {
"label": "冻结语言模型",
"info": "冻结模型中的语言模型。",
},
"ko": {
"label": "언어 모델 고정",
"info": "모델의 언어 모델을 고정합니다.",
},
"ja": {
"label": "言語モデルの固定",
"info": "モデルの言語モデルを固定します。",
},
},
"image_max_pixels": {
"en": {
"label": "Image max pixels",
"info": "The maximum number of pixels of image inputs.",
},
"ru": {
"label": "Максимальное количество пикселей изображения",
"info": "Максимальное количество пикселей изображения.",
},
"zh": {
"label": "图像最大像素",
"info": "输入图像的最大像素数。",
},
"ko": {
"label": "이미지 최대 픽셀",
"info": "이미지 입력의 최대 픽셀 수입니다.",
},
"ja": {
"label": "画像最大ピクセル",
"info": "画像入力の最大ピクセル数です。",
},
},
"image_min_pixels": {
"en": {
"label": "Image min pixels",
"info": "The minimum number of pixels of image inputs.",
},
"ru": {
"label": "Минимальное количество пикселей изображения",
"info": "Минимальное количество пикселей изображения.",
},
"zh": {
"label": "图像最小像素",
"info": "输入图像的最小像素数。",
},
"ko": {
"label": "이미지 최소 픽셀",
"info": "이미지 입력의 최소 픽셀 수입니다.",
},
"ja": {
"label": "画像最小ピクセル",
"info": "画像入力の最小ピクセル数です。",
},
},
"video_max_pixels": {
"en": {
"label": "Video max pixels",
"info": "The maximum number of pixels of video inputs.",
},
"ru": {
"label": "Максимальное количество пикселей видео",
"info": "Максимальное количество пикселей видео.",
},
"zh": {
"label": "视频最大像素",
"info": "输入视频的最大像素数。",
},
"ko": {
"label": "비디오 최대 픽셀",
"info": "비디오 입력의 최대 픽셀 수입니다.",
},
"ja": {
"label": "ビデオ最大ピクセル",
"info": "ビデオ入力の最大ピクセル数です。",
},
},
"video_min_pixels": {
"en": {
"label": "Video min pixels",
"info": "The minimum number of pixels of video inputs.",
},
"ru": {
"label": "Минимальное количество пикселей видео",
"info": "Минимальное количество пикселей видео.",
},
"zh": {
"label": "视频最小像素",
"info": "输入视频的最小像素数。",
},
"ko": {
"label": "비디오 최소 픽셀",
"info": "비디오 입력의 최소 픽셀 수입니다.",
},
"ja": {
"label": "ビデオ最小ピクセル",
"info": "ビデオ入力の最小ピクセル数です。",
},
},
"galore_tab": {
"en": {
"label": "GaLore configurations",
......@@ -2779,6 +2972,13 @@ ALERTS = {
"ko": "출력 디렉토리가 이미 존재합니다. 위 출력 디렉토리에 저장된 학습을 재개합니다.",
"ja": "出力ディレクトリが既に存在します。このチェックポイントからトレーニングを再開します。",
},
"warn_no_instruct": {
"en": "You are using a non-instruct model, please fine-tune it first.",
"ru": "Вы используете модель без инструкции, пожалуйста, primeros выполните донастройку этой модели.",
"zh": "您正在使用非指令模型,请先对其进行微调。",
"ko": "당신은 지시하지 않은 모델을 사용하고 있습니다. 먼저 이를 미세 조정해 주세요.",
"ja": "インストラクションモデルを使用していません。まずモデルをアダプターに適合させてください。",
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"ru": "Прервано, ожидание завершения...",
......
......@@ -22,13 +22,14 @@ from typing import TYPE_CHECKING, Any, Optional
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_accelerator_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available
from .common import (
DEFAULT_CACHE_DIR,
DEFAULT_CONFIG_DIR,
abort_process,
calculate_pixels,
gen_cmd,
get_save_dir,
load_args,
......@@ -162,6 +163,7 @@ class Runner:
mask_history=get("train.mask_history"),
resize_vocab=get("train.resize_vocab"),
use_llama_pro=get("train.use_llama_pro"),
enable_thinking=get("train.enable_thinking"),
report_to=get("train.report_to"),
use_galore=get("train.use_galore"),
use_apollo=get("train.use_apollo"),
......@@ -235,6 +237,16 @@ class Runner:
args["pref_ftx"] = get("train.pref_ftx")
args["pref_loss"] = get("train.pref_loss")
# multimodal config
if model_name in MULTIMODAL_SUPPORTED_MODELS:
args["freeze_vision_tower"] = get("train.freeze_vision_tower")
args["freeze_multi_modal_projector"] = get("train.freeze_multi_modal_projector")
args["freeze_language_model"] = get("train.freeze_language_model")
args["image_max_pixels"] = calculate_pixels(get("train.image_max_pixels"))
args["image_min_pixels"] = calculate_pixels(get("train.image_min_pixels"))
args["video_max_pixels"] = calculate_pixels(get("train.video_max_pixels"))
args["video_min_pixels"] = calculate_pixels(get("train.video_min_pixels"))
# galore config
if args["use_galore"]:
args["galore_rank"] = get("train.galore_rank")
......@@ -256,12 +268,6 @@ class Runner:
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")
# report_to
if "none" in args["report_to"]:
args["report_to"] = "none"
elif "all" in args["report_to"]:
args["report_to"] = "all"
# swanlab config
if get("train.use_swanlab"):
args["swanlab_project"] = get("train.swanlab_project")
......
......@@ -50,7 +50,7 @@ def test_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
"</s>",
]
......@@ -60,7 +60,7 @@ def test_multi_function_formatter():
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
"</s>",
]
......@@ -85,7 +85,7 @@ def test_default_tool_formatter():
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
......@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
......@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
]
def test_llama3_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
"""<|eot_id|>"""
]
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
......@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_multi_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = (
"""[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
......
......@@ -135,8 +135,7 @@ def _check_plugin(
expected_mm_inputs: dict[str, Any] = {},
expected_no_mm_inputs: dict[str, Any] = {},
) -> None:
# test omni_messages
if plugin.__class__.__name__ == "Qwen2OmniPlugin":
if plugin.__class__.__name__ == "Qwen2OmniPlugin": # test omni_messages
assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
expected_input_ids,
......@@ -146,8 +145,7 @@ def _check_plugin(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
# test mm_messages
if plugin.__class__.__name__ != "BasePlugin":
elif plugin.__class__.__name__ != "BasePlugin": # test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
expected_input_ids,
......@@ -201,7 +199,7 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
@pytest.mark.xfail(reason="Unknown error.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():
image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
......@@ -219,7 +217,7 @@ def test_internvl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.xfail(reason="Unknown error.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
def test_llama4_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
processor = tokenizer_module["processor"]
......@@ -321,10 +319,9 @@ def test_pixtral_plugin():
_check_plugin(**check_inputs)
@pytest.mark.xfail(reason="Unknown error.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_qwen2_omni_plugin():
image_seqlen = 4
audio_seqlen = 2
image_seqlen, audio_seqlen = 4, 2
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
qwen2_omni_plugin = get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
......
......@@ -125,6 +125,61 @@ def test_encode_multiturn(use_fast: bool):
)
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
prompt_str = (
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
f"{MESSAGES[1]['content']}<|im_end|>\n"
f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
if not cot_messages or enable_thinking is False:
answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
if enable_thinking:
answer_str = "<think>\n\n</think>\n\n" + answer_str
else:
prompt_str = prompt_str + "<think>\n\n</think>\n\n"
else:
answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)
messages = MESSAGES if not cot_messages or enable_thinking is False else MESSAGES_WITH_THOUGHT
prompt_str_1 = f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
prompt_str_2 = f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
if not cot_messages or enable_thinking is False:
if enable_thinking:
answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
else:
prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
_check_tokenization(
tokenizer,
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
)
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
......@@ -162,12 +217,12 @@ def test_get_stop_token_ids():
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
prompt_str = (
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
"<start_of_turn>user\n你好<end_of_turn>\n"
f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
"<start_of_turn>model\n"
)
answer_str = "很高兴认识你!<end_of_turn>\n"
answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
......@@ -175,12 +230,12 @@ def test_gemma_template(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str = "很高兴认识你!<|eot_id|>"
answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
......@@ -189,52 +244,64 @@ def test_llama3_template(use_fast: bool):
)
def test_llama4_template(use_fast: bool):
prompt_str = (
"<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
"<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
"<|header_start|>user<|header_end|>\n\n你好<|eot|>"
f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
"<|header_start|>assistant<|header_end|>\n\n"
)
answer_str = "很高兴认识你!<|eot|>"
answer_str = f"{MESSAGES[3]['content']}<|eot|>"
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
"use_fast",
[
pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
],
)
def test_phi4_template(use_fast: bool):
prompt_str = (
"<|im_start|>user<|im_sep|>How are you<|im_end|>"
"<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
"<|im_start|>user<|im_sep|>你好<|im_end|>"
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
"<|im_start|>assistant<|im_sep|>"
)
answer_str = "很高兴认识你!<|im_end|>"
answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
prompt_str = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>\n"
answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen3_template(use_fast: bool):
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
if not cot_messages:
answer_str = f"<think>\n\n</think>\n\n{MESSAGES[3]['content']}<|im_end|>\n"
messages = MESSAGES
else:
answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
messages = MESSAGES_WITH_THOUGHT
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
def test_parse_llama3_template():
......@@ -250,9 +317,11 @@ def test_parse_llama3_template():
assert template.default_system == ""
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
template = parse_template(tokenizer)
assert template.__class__.__name__ == "Template"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
......@@ -260,9 +329,11 @@ def test_parse_qwen_template():
assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
template = parse_template(tokenizer)
assert template.__class__.__name__ == "ReasoningTemplate"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
......
......@@ -20,7 +20,7 @@ from llamafactory.chat import ChatModel
from llamafactory.extras.packages import is_sglang_available
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
INFER_ARGS = {
......
......@@ -16,6 +16,7 @@ import pytest
import torch
from transformers import AutoConfig, AutoModelForVision2Seq
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter
......@@ -45,10 +46,12 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
assert param.requires_grad != freeze_language_model
@pytest.mark.parametrize("freeze_vision_tower", (False, True))
def test_visual_lora(freeze_vision_tower: bool):
@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
finetuning_args = FinetuningArguments(
finetuning_type="lora", freeze_vision_tower=freeze_vision_tower, freeze_language_model=freeze_language_model
)
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
with torch.device("meta"):
model = AutoModelForVision2Seq.from_config(config)
......@@ -61,10 +64,15 @@ def test_visual_lora(freeze_vision_tower: bool):
else:
frozen_params.add(name)
if freeze_vision_tower:
assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params
if is_transformers_version_greater_than("4.52.0"):
visual_param_name = "base_model.model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
language_param_name = "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.default.weight"
merger_param_name = "base_model.model.model.visual.merger.lora_A.default.weight"
else:
assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params
visual_param_name = "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
language_param_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
merger_param_name = "base_model.model.visual.merger.lora_A.default.weight"
assert "merger" not in trainable_params
assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params
assert (visual_param_name in trainable_params) != freeze_vision_tower
assert (language_param_name in trainable_params) != freeze_language_model
assert (merger_param_name in trainable_params) is False
# change if test fails or cache is outdated
0.9.3.106
0.9.3.107
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment