Commit c7d1b209 authored by chenych

Update 0429

parent c8d12c06
...@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:
 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     tools = [
         {
......
...@@ -32,7 +32,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
     baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
-            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
+            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu", weights_only=True)
             baichuan2_state_dict.update(shard_weight)

     llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
......
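The `weights_only=True` switch above restricts `torch.load` to tensors and primitive containers instead of arbitrary pickled objects, which is the safer way to ingest third-party checkpoint shards. A minimal sketch of the same shard-merging pattern, assuming a directory of `*.bin` shards (the function name is illustrative, not from the repo):

import os
from collections import OrderedDict

import torch


def load_sharded_state_dict(input_dir: str) -> dict:
    # Merge every *.bin shard found in input_dir into a single state dict.
    state_dict = OrderedDict()
    for filename in sorted(os.listdir(input_dir)):
        if filename.endswith(".bin"):
            # weights_only=True refuses to unpickle arbitrary Python objects,
            # so a tampered shard cannot execute code at load time.
            shard = torch.load(os.path.join(input_dir, filename), map_location="cpu", weights_only=True)
            state_dict.update(shard)

    return state_dict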
...@@ -17,23 +17,8 @@ r"""Efficient fine-tuning of large language models. ...@@ -17,23 +17,8 @@ r"""Efficient fine-tuning of large language models.
Level: Level:
api, webui > chat, eval, train > data, model > hparams > extras api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
main:
transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.5.0
accelerate>=0.34.0,<=1.6.0
peft>=0.14.0,<=0.15.1
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<4.48.0
packing:
transformers>=4.43.0
Disable version checking: DISABLE_VERSION_CHECK=1 Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1 Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1 Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1 Use modelscope: USE_MODELSCOPE_HUB=1
......
...@@ -25,7 +25,6 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
-from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
...@@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine):
             inputs=inputs,
             attention_mask=attention_mask,
             generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor(),
         )

         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
......
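With the custom logits processor removed, sampling behavior is carried entirely by GenerationConfig. A small sketch of that calling convention with plain transformers, using a placeholder model rather than the engine code itself:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")

generating_args = {"do_sample": True, "temperature": 0.9, "top_p": 0.7, "max_new_tokens": 32}
inputs = tokenizer("Hello, LLaMA Factory!", return_tensors="pt")

# all sampling hyperparameters travel inside GenerationConfig
output_ids = model.generate(**inputs, generation_config=GenerationConfig(**generating_args))
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))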
...@@ -16,17 +16,7 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
-
-from . import launcher
-from .api.app import run_api
-from .chat.chat_model import run_chat
-from .eval.evaluator import run_eval
-from .extras import logging
-from .extras.env import VERSION, print_env
-from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-from .train.tuner import export_model, run_exp
-from .webui.interface import run_web_demo, run_web_ui
+from functools import partial

 USAGE = (
...@@ -44,7 +34,21 @@ USAGE = (
     + "-" * 70
 )

-WELCOME = (
+
+def main():
+    from . import launcher
+    from .api.app import run_api
+    from .chat.chat_model import run_chat
+    from .eval.evaluator import run_eval
+    from .extras import logging
+    from .extras.env import VERSION, print_env
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .train.tuner import export_model, run_exp
+    from .webui.interface import run_web_demo, run_web_ui
+
+    logger = logging.get_logger(__name__)
+
+    WELCOME = (
         "-" * 58
         + "\n"
         + f"| Welcome to LLaMA Factory, version {VERSION}"
...@@ -54,40 +58,24 @@ WELCOME = (
         + "|\n"
         + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
         + "-" * 58
     )

-logger = logging.get_logger(__name__)
-
-
-@unique
-class Command(str, Enum):
-    API = "api"
-    CHAT = "chat"
-    ENV = "env"
-    EVAL = "eval"
-    EXPORT = "export"
-    TRAIN = "train"
-    WEBDEMO = "webchat"
-    WEBUI = "webui"
-    VER = "version"
-    HELP = "help"
-
-
-def main():
-    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
-    if command == Command.API:
-        run_api()
-    elif command == Command.CHAT:
-        run_chat()
-    elif command == Command.ENV:
-        print_env()
-    elif command == Command.EVAL:
-        run_eval()
-    elif command == Command.EXPORT:
-        export_model()
-    elif command == Command.TRAIN:
-        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-        if force_torchrun or (get_device_count() > 1 and not use_ray()):
+    COMMAND_MAP = {
+        "api": run_api,
+        "chat": run_chat,
+        "env": print_env,
+        "eval": run_eval,
+        "export": export_model,
+        "train": run_exp,
+        "webchat": run_web_demo,
+        "webui": run_web_ui,
+        "version": partial(print, WELCOME),
+        "help": partial(print, USAGE),
+    }
+
+    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
...@@ -123,19 +111,14 @@ def main():
             check=True,
         )
         sys.exit(process.returncode)
-    else:
-        run_exp()
-    elif command == Command.WEBDEMO:
-        run_web_demo()
-    elif command == Command.WEBUI:
-        run_web_ui()
-    elif command == Command.VER:
-        print(WELCOME)
-    elif command == Command.HELP:
-        print(USAGE)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
         print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
+    from multiprocessing import freeze_support
+
+    freeze_support()
     main()
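The rewritten entry point replaces the Command enum and its if/elif chain with a dictionary of callables built inside main(), which also keeps the heavy subpackage imports lazy. A stripped-down sketch of the same dispatch pattern (the demo commands here are illustrative, not LLaMA-Factory's):

import sys
from functools import partial

USAGE = "usage: demo {hello,version,help}"
WELCOME = "demo CLI, version 0.0.1"


def run_hello() -> None:
    print("hello from the demo CLI")


def main() -> None:
    command_map = {
        "hello": run_hello,
        "version": partial(print, WELCOME),
        "help": partial(print, USAGE),
    }
    # sys.argv[0] is the program name, so a subcommand (if any) sits at index 1
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command in command_map:
        command_map[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    main()

The sketch guards with len(sys.argv) > 1 so that invoking the script with no arguments falls back to help instead of raising IndexError from pop(1).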
...@@ -176,7 +176,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "input_ids": features["input_ids"],
                 "image_grid_thw": mm_inputs.get("image_grid_thw"),
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
-                "attention_mask": features["attention_mask"],
+                "attention_mask": (features["attention_mask"] >= 1).float(),
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
......
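The collator change feeds the rope-index computation a strictly binary float mask: with sequence packing the attention mask can carry segment indices greater than 1, and `(mask >= 1).float()` collapses those back to 0/1. A tiny illustration of the conversion:

import torch

# 0 = padding, 1 and 2 = tokens from two sequences packed into one row
attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0, 0]])

binary_mask = (attention_mask >= 1).float()
print(binary_mask)  # tensor([[1., 1., 1., 1., 1., 0., 0.]])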
This diff is collapsed.
...@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

 from typing_extensions import override

 from ..extras import logging
-from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
...@@ -61,7 +61,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
...@@ -77,7 +77,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
...@@ -111,12 +111,18 @@ class Template:
         return token_ids

+    def _remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
+        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
...@@ -134,14 +140,18 @@ class Template:
                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                     elements += self.format_system.apply(content=(system + tool_text))

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...@@ -318,6 +328,7 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
+        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
...@@ -331,14 +342,18 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=system_text + content)
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...@@ -477,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
     if len(user_slot) > len(user_slot_empty_system):
         default_system = find_diff(user_slot_empty_system, user_slot)
...@@ -518,9 +534,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     template = TEMPLATES[data_args.template]

-    if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        check_version("transformers>=4.45.0")
-
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
...@@ -871,6 +884,18 @@ register_template(
 )

+register_template(
+    name="granite3_vision",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
 register_template(
     name="index",
     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
...@@ -923,6 +948,20 @@ register_template(
 )

+register_template(
+    name="intern_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    default_system=(
+        "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+    ),
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
+)
+
 register_template(
     name="kimi_vl",
     format_user=StringFormatter(
...@@ -1389,6 +1428,21 @@ register_template(
 )

+# copied from qwen template
+register_template(
+    name="qwen3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+)
+
 # copied from chatml template
 register_template(
     name="qwen2_audio",
......
...@@ -22,7 +22,7 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

-AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
+AUDIO_PLACEHOLDER = os.getenv("AUDIO_PLACEHOLDER", "<audio>")

 CHECKPOINT_NAMES = {
     SAFE_ADAPTER_WEIGHTS_NAME,
...@@ -50,7 +50,7 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
+IMAGE_PLACEHOLDER = os.getenv("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}
...@@ -89,7 +89,7 @@ SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 SWANLAB_CONFIG = "swanlab_public_config.json"

-VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
+VIDEO_PLACEHOLDER = os.getenv("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"
...@@ -838,11 +838,46 @@ register_model_group(
             DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
         },
"Granite-3.2-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.2-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-2b-instruct",
},
"Granite-3.2-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.2-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-8b-instruct",
},
"Granite-3.3-2B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-base",
},
"Granite-3.3-8B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-base",
},
"Granite-3.3-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-instruct",
},
"Granite-3.3-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-instruct",
},
     },
     template="granite3",
 )
register_model_group(
models={
"Granite-3.2-1B-A400M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
},
},
template="granite3_vision",
)
 register_model_group(
     models={
         "Hunyuan-7B-Instruct": {
...@@ -965,6 +1000,46 @@ register_model_group(
 )
register_model_group(
models={
"InternVL2.5-2B-MPO": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-2B-MPO-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-2B-MPO-hf",
},
"InternVL2.5-8B-MPO": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-8B-MPO-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-8B-MPO-hf",
},
"InternVL3-1B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-1B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-1B-hf",
},
"InternVL3-2B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-2B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-2B-hf",
},
"InternVL3-8B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-8B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-8B-hf",
},
"InternVL3-14B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-14B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-14B-hf",
},
"InternVL3-38B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-38B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-38B-hf",
},
"InternVL3-78B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-78B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-78B-hf",
},
},
template="intern_vl",
multimodal=True,
)
 register_model_group(
     models={
         "Jamba-v0.1": {
...@@ -2328,6 +2403,69 @@ register_model_group(
 )
register_model_group(
models={
"Qwen3-0.6B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
},
"Qwen3-1.7B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
},
"Qwen3-4B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
},
"Qwen3-8B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
},
"Qwen3-14B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
},
"Qwen3-30B-A3B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
},
"Qwen3-0.6B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
},
"Qwen3-1.7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
},
"Qwen3-4B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
},
"Qwen3-8B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
},
"Qwen3-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
},
"Qwen3-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
},
"Qwen3-30B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
},
"Qwen3-235B-A22B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
},
},
template="qwen3",
)
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
......
...@@ -79,7 +79,7 @@ class _Logger(logging.Logger):
 def _get_default_logging_level() -> "logging._Level":
     r"""Return the default logging level."""
-    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
             return logging._nameToLevel[env_level_str.upper()]
......
...@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.5.0")
     check_version("accelerate>=0.34.0,<=1.6.0")
     check_version("peft>=0.14.0,<=0.15.1")
...@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
 def get_current_device() -> "torch.device":
     r"""Get the current available device."""
     if is_torch_xpu_available():
-        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
-        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_mps_available():
-        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_cuda_available():
-        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
     else:
         device = "cpu"
...@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":
 def get_device_count() -> int:
-    r"""Get the number of available GPU or NPU devices."""
+    r"""Get the number of available devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
         return torch.npu.device_count()
+    elif is_torch_mps_available():
+        return torch.mps.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
     else:
...@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":
 def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
-    if is_torch_npu_available():
-        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
-    elif is_torch_xpu_available():
+    if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
+    elif is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), -1
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
...@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
         return torch.float32

-def is_gpu_or_npu_available() -> bool:
-    r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
+def is_accelerator_available() -> bool:
+    r"""Check if the accelerator is available."""
+    return (
+        is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
+    )

 def is_env_enabled(env_var: str, default: str = "0") -> bool:
...@@ -229,7 +235,7 @@ def skip_check_imports() -> None:
 def torch_gc() -> None:
-    r"""Collect GPU or NPU memory."""
+    r"""Collect the device memory."""
     gc.collect()
     if is_torch_xpu_available():
         torch.xpu.empty_cache()
...@@ -280,7 +286,7 @@ def use_ray() -> bool:
 def find_available_port() -> int:
-    """Find an available port on the local machine."""
+    r"""Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]
...@@ -288,8 +294,8 @@ def find_available_port() -> int:
     return port

-def fix_proxy(ipv6_enabled: bool) -> None:
-    """Fix proxy settings for gradio ui."""
+def fix_proxy(ipv6_enabled: bool = False) -> None:
+    r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
         for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
......
...@@ -411,6 +411,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_muon: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the Muon optimizer."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
...@@ -431,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
......
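`early_stopping_steps` is a new knob whose help text ties it to `metric_for_best_model`; how LLaMA-Factory wires it internally is not shown in this diff. For reference, the stock way to get similar behavior with vanilla transformers is EarlyStoppingCallback, whose patience is counted in evaluation rounds rather than raw optimizer steps:

from transformers import EarlyStoppingCallback, TrainingArguments

# Early stopping in transformers compares `metric_for_best_model` after each
# evaluation and stops once `early_stopping_patience` rounds pass without
# improvement, so eval_steps * patience acts as the effective step budget.
training_args = TrainingArguments(
    output_dir="saves/demo",
    eval_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)
# pass callbacks=[early_stopping] when constructing the Trainer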
...@@ -65,7 +65,13 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    new_special_tokens: Optional[str] = field(
+    add_tokens: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
+        },
+    )
+    add_special_tokens: Optional[str] = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
...@@ -176,8 +182,11 @@ class BaseModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

-        if self.new_special_tokens is not None:  # support multiple special tokens
-            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+        if self.add_tokens is not None:  # support multiple tokens
+            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
+
+        if self.add_special_tokens is not None:  # support multiple special tokens
+            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]

 @dataclass
...@@ -222,6 +231,10 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Use pan and scan to process image for gemma3."},
     )
+    crop_to_patches: bool = field(
+        default=False,
+        metadata={"help": "Whether to crop the image to patches for internvl."},
+    )
     use_audio_in_video: bool = field(
         default=False,
         metadata={"help": "Whether or not to use audio in video inputs."},
......
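Splitting the old `new_special_tokens` option into `add_tokens` and `add_special_tokens` mirrors the two tokenizer APIs: regular added tokens still go through normalization, while special tokens are never split and can be skipped during decoding. A hedged sketch of how such comma-separated lists are typically consumed (placeholder model and token names, not the exact patcher code):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

add_tokens = ["<figure>", "</figure>"]          # e.g. --add_tokens "<figure>,</figure>"
add_special_tokens = ["<|tool_call|>"]          # e.g. --add_special_tokens "<|tool_call|>"

num_added = tokenizer.add_tokens(add_tokens, special_tokens=False)
num_added += tokenizer.add_special_tokens({"additional_special_tokens": add_special_tokens})

if num_added > 0:
    # grow the embedding matrix to cover the new vocabulary entries
    model.resize_token_embeddings(len(tokenizer))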
...@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
 import torch
 import transformers
 import yaml
+from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
...@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if args is not None:
         return args

-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]
...@@ -91,6 +96,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()

+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+
+        # avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
...@@ -279,12 +292,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
...@@ -321,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False  # important for multimodal dataset
+    if finetuning_args.finetuning_type == "lora":
+        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
+        training_args.label_names = training_args.label_names or ["labels"]
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
...@@ -407,6 +429,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
...@@ -428,9 +451,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()

     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
......
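read_args now layers OmegaConf dot-list overrides from the remaining command-line arguments on top of the YAML/JSON config, so a single hyperparameter can be changed without editing the file (for example `llamafactory-cli train config.yaml learning_rate=1e-5`, following the override syntax in the new code). A small self-contained sketch of that merge:

import yaml
from omegaconf import OmegaConf

base_yaml = """\
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
learning_rate: 1.0e-4
lora_rank: 8
"""

dict_config = yaml.safe_load(base_yaml)
# OmegaConf.from_cli parses key=value pairs such as those given after the config path
override_config = OmegaConf.from_cli(["learning_rate=1e-5", "lora_rank=16"])

merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(merged)  # the command-line values win over the file values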
...@@ -34,6 +34,10 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
+    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+        default=None,
+        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
+    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
...@@ -55,6 +59,17 @@ class RayArguments:
         self.use_ray = use_ray()
         if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
             self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
+
+        if self.ray_storage_filesystem is not None:
+            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
+                raise ValueError(
+                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
+                )
+            import pyarrow.fs as fs
+
+            if self.ray_storage_filesystem == "s3":
+                self.ray_storage_filesystem = fs.S3FileSystem()
+            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
+                self.ray_storage_filesystem = fs.GcsFileSystem()

 @dataclass
......
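ray_storage_filesystem turns a short scheme name into a pyarrow filesystem object in __post_init__, which can then be handed to Ray Train together with storage_path. A minimal sketch of that resolution step (bucket paths and credentials are assumed to come from the environment):

from typing import Optional

import pyarrow.fs as fs


def resolve_storage_filesystem(name: Optional[str]):
    # Mirrors the RayArguments post-init: None means the local filesystem.
    if name is None:
        return None
    if name == "s3":
        return fs.S3FileSystem()
    if name in ("gs", "gcs"):
        return fs.GcsFileSystem()
    raise ValueError(f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {name}")


storage_fs = resolve_storage_filesystem(None)  # pass "s3", "gs" or "gcs" for remote storage
# e.g. ray.train.RunConfig(storage_path="bucket/saves", storage_filesystem=storage_fs)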
...@@ -23,7 +23,7 @@ from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
-from .model_utils.visual import get_forbidden_modules, patch_target_modules
+from .model_utils.visual import COMPOSITE_MODELS, get_forbidden_modules, patch_target_modules

 if TYPE_CHECKING:
...@@ -100,7 +100,7 @@ def _setup_freeze_tuning(
             hidden_modules.add(name.split(".1.")[-1].split(".")[0])

         if re.search(r"\.\d+\.", name) is None:
-            non_hidden_modules.add(name.split(".")[-2])
+            non_hidden_modules.add(name.split(".")[-2])  # remove weight/bias

     trainable_layers = []
     for module_name in finetuning_args.freeze_trainable_modules:
...@@ -121,6 +121,10 @@ def _setup_freeze_tuning(
             trainable_layers.append(module_name)

+    model_type = getattr(model.config, "model_type", None)
+    if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
+        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
...@@ -204,7 +208,7 @@ def _setup_lora_tuning(
     if (
         finetuning_args.use_dora
         and getattr(model, "quantization_method", None) is not None
-        and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+        and getattr(model, "quantization_method", None) != QuantizationMethod.BNB
     ):
         raise ValueError("DoRA is not compatible with PTQ-quantized models.")
......
...@@ -19,7 +19,6 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
...@@ -30,6 +29,7 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
+from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
...@@ -39,6 +39,10 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model

+if is_transformers_version_greater_than("4.46.0"):
+    from transformers import AutoModelForImageTextToText
+
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
...@@ -97,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Failed to load processor: {e}.")
+        logger.info_rank0(f"Failed to load processor: {e}.")
         processor = None

     # Avoid load tokenizer, see:
...@@ -145,7 +149,10 @@ def load_model(
     else:
         if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
             load_class = AutoModelForVision2Seq
-        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+        elif (
+            is_transformers_version_greater_than("4.46.0")
+            and type(config) in AutoModelForImageTextToText._model_mapping.keys()
+        ):  # image-text
             load_class = AutoModelForImageTextToText
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
......
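AutoModelForImageTextToText only exists in newer transformers releases, so both the import and the `_model_mapping` lookup are now gated on is_transformers_version_greater_than("4.46.0"), a helper from ..extras.packages. The equivalent check with the standard packaging utilities looks roughly like this (helper name local to the sketch):

import importlib.metadata

from packaging.version import Version


def transformers_at_least(target: str) -> bool:
    # Compare the installed transformers version against a minimum requirement.
    return Version(importlib.metadata.version("transformers")) >= Version(target)


if transformers_at_least("4.46.0"):
    from transformers import AutoModelForImageTextToText  # noqa: F401
else:
    AutoModelForImageTextToText = None  # callers must version-check before using it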
...@@ -18,7 +18,6 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_availabl
 from ...extras import logging
 from ...extras.constants import AttentionFunction
-from ...extras.misc import check_version

 if TYPE_CHECKING:
...@@ -36,8 +35,6 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
-                check_version("transformers>=4.42.4")
-                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != AttentionFunction.FA2:
                     logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = AttentionFunction.FA2
...@@ -72,6 +69,9 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
         setattr(config, "attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "kimi_vl":
+        setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
......
...@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
     ) -> "torch.Tensor":
         saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
         with torch.no_grad():
-            output = forward_function(hidden_states, *args)
+            outputs = forward_function(hidden_states, *args)

         ctx.save_for_backward(saved_hidden_states)
         ctx.forward_function = forward_function
         ctx.args = args
-        return output
+        return outputs

     @staticmethod
     @torch.cuda.amp.custom_bwd
...@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
         hidden_states.requires_grad_(True)
         with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+            outputs = ctx.forward_function(hidden_states, *ctx.args)
+            output = outputs[0] if isinstance(outputs, tuple) else outputs

         torch.autograd.backward(output, grad_output)
         return (None, hidden_states.grad) + (None,) * len(ctx.args)
......
...@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
 import torch
 import torch.nn as nn
 import transformers
-from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
...@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than
 if not is_transformers_version_greater_than("4.48.0"):
-    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+    from transformers.models.llama.modeling_llama import (
+        Cache,
+        LlamaAttention,
+        LlamaFlashAttention2,
+        LlamaSdpaAttention,
+        apply_rotary_pos_emb,
+        repeat_kv,
+    )

 if TYPE_CHECKING:
...@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than("4.43.0"):
-        from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
     attn_output: torch.Tensor = _flash_attention_forward(
         query_states,
         key_states,
...@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
         use_top_left_mask=self._flash_attn_uses_top_left_mask,
         is_causal=self.is_causal,
     )
-    else:
-        attn_output: torch.Tensor = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-        )

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
...@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.41.2,<4.48.0")
+    check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......