Commit c7d1b209 authored by chenych's avatar chenych
Browse files

Update 0429

parent c8d12c06
......@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:
def main():
client = OpenAI(
api_key="{}".format(os.environ.get("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
api_key="{}".format(os.getenv("API_KEY", "0")),
base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
)
tools = [
{
......
......@@ -32,7 +32,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu", weights_only=True)
baichuan2_state_dict.update(shard_weight)
llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
......
......@@ -17,23 +17,8 @@ r"""Efficient fine-tuning of large language models.
Level:
api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
main:
transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.5.0
accelerate>=0.34.0,<=1.6.0
peft>=0.14.0,<=0.15.1
trl>=0.8.6,<=0.9.6
attention:
transformers>=4.42.4 (gemma+fa2)
longlora:
transformers>=4.41.2,<4.48.0
packing:
transformers>=4.43.0
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
......
......@@ -25,7 +25,6 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
......@@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine):
inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
......
......@@ -16,17 +16,7 @@ import os
import subprocess
import sys
from copy import deepcopy
from enum import Enum, unique
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
from functools import partial
USAGE = (
......@@ -44,7 +34,21 @@ USAGE = (
+ "-" * 70
)
WELCOME = (
def main():
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
logger = logging.get_logger(__name__)
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
......@@ -54,40 +58,24 @@ WELCOME = (
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
logger = logging.get_logger(__name__)
@unique
class Command(str, Enum):
API = "api"
CHAT = "chat"
ENV = "env"
EVAL = "eval"
EXPORT = "export"
TRAIN = "train"
WEBDEMO = "webchat"
WEBUI = "webui"
VER = "version"
HELP = "help"
)
COMMAND_MAP = {
"api": run_api,
"chat": run_chat,
"env": print_env,
"eval": run_eval,
"export": export_model,
"train": run_exp,
"webchat": run_web_demo,
"webui": run_web_ui,
"version": partial(print, WELCOME),
"help": partial(print, USAGE),
}
def main():
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
elif command == Command.CHAT:
run_chat()
elif command == Command.ENV:
print_env()
elif command == Command.EVAL:
run_eval()
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()):
command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
......@@ -123,19 +111,14 @@ def main():
check=True,
)
sys.exit(process.returncode)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI:
run_web_ui()
elif command == Command.VER:
print(WELCOME)
elif command == Command.HELP:
print(USAGE)
elif command in COMMAND_MAP:
COMMAND_MAP[command]()
else:
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__":
from multiprocessing import freeze_support
freeze_support()
main()
......@@ -176,7 +176,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
"attention_mask": (features["attention_mask"] >= 1).float(),
}
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
......
This diff is collapsed.
......@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from typing_extensions import override
from ..extras import logging
from ..extras.misc import check_version
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
......@@ -61,7 +61,7 @@ class Template:
tools: Optional[str] = None,
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
......@@ -77,7 +77,7 @@ class Template:
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
......@@ -111,12 +111,18 @@ class Template:
return token_ids
def _remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
return re.sub(pattern, "", content).lstrip("\n")
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
remove_thought: bool,
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
......@@ -134,14 +140,18 @@ class Template:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER:
elements += self.format_user.apply(content=content, idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content)
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content)
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
......@@ -318,6 +328,7 @@ class Llama2Template(Template):
messages: list[dict[str, str]],
system: str,
tools: str,
remove_thought: bool,
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
......@@ -331,14 +342,18 @@ class Llama2Template(Template):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER:
elements += self.format_user.apply(content=system_text + content)
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content)
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content)
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
......@@ -477,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n") # remove thought tags
if len(user_slot) > len(user_slot_empty_system):
default_system = find_diff(user_slot_empty_system, user_slot)
......@@ -518,9 +534,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template = TEMPLATES[data_args.template]
if template.mm_plugin.__class__.__name__ != "BasePlugin":
check_version("transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
......@@ -871,6 +884,18 @@ register_template(
)
register_template(
name="granite3_vision",
format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
......@@ -923,6 +948,20 @@ register_template(
)
register_template(
name="intern_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
),
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
)
register_template(
name="kimi_vl",
format_user=StringFormatter(
......@@ -1389,6 +1428,21 @@ register_template(
)
# copied from qwen template
register_template(
name="qwen3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
)
# copied from chatml template
register_template(
name="qwen2_audio",
......
......@@ -22,7 +22,7 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
AUDIO_PLACEHOLDER = os.getenv("AUDIO_PLACEHOLDER", "<audio>")
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
......@@ -50,7 +50,7 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
IMAGE_PLACEHOLDER = os.getenv("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"}
......@@ -89,7 +89,7 @@ SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
SWANLAB_CONFIG = "swanlab_public_config.json"
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
VIDEO_PLACEHOLDER = os.getenv("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin"
......@@ -838,11 +838,46 @@ register_model_group(
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
},
"Granite-3.2-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.2-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-2b-instruct",
},
"Granite-3.2-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.2-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-8b-instruct",
},
"Granite-3.3-2B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-base",
},
"Granite-3.3-8B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-base",
},
"Granite-3.3-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-instruct",
},
"Granite-3.3-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-instruct",
},
},
template="granite3",
)
register_model_group(
models={
"Granite-3.2-1B-A400M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
},
},
template="granite3_vision",
)
register_model_group(
models={
"Hunyuan-7B-Instruct": {
......@@ -965,6 +1000,46 @@ register_model_group(
)
register_model_group(
models={
"InternVL2.5-2B-MPO": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-2B-MPO-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-2B-MPO-hf",
},
"InternVL2.5-8B-MPO": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-8B-MPO-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-8B-MPO-hf",
},
"InternVL3-1B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-1B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-1B-hf",
},
"InternVL3-2B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-2B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-2B-hf",
},
"InternVL3-8B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-8B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-8B-hf",
},
"InternVL3-14B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-14B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-14B-hf",
},
"InternVL3-38B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-38B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-38B-hf",
},
"InternVL3-78B-hf": {
DownloadSource.DEFAULT: "OpenGVLab/InternVL3-78B-hf",
DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-78B-hf",
},
},
template="intern_vl",
multimodal=True,
)
register_model_group(
models={
"Jamba-v0.1": {
......@@ -2328,6 +2403,69 @@ register_model_group(
)
register_model_group(
models={
"Qwen3-0.6B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
},
"Qwen3-1.7B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
},
"Qwen3-4B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
},
"Qwen3-8B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
},
"Qwen3-14B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
},
"Qwen3-30B-A3B-Base": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
},
"Qwen3-0.6B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
},
"Qwen3-1.7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
},
"Qwen3-4B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
},
"Qwen3-8B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
},
"Qwen3-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
},
"Qwen3-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
},
"Qwen3-30B-A3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
},
"Qwen3-235B-A22B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
},
},
template="qwen3",
)
register_model_group(
models={
"Qwen2-Audio-7B": {
......
......@@ -79,7 +79,7 @@ class _Logger(logging.Logger):
def _get_default_logging_level() -> "logging._Level":
r"""Return the default logging level."""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
......
......@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.5.0")
check_version("accelerate>=0.34.0,<=1.6.0")
check_version("peft>=0.14.0,<=0.15.1")
......@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
def get_current_device() -> "torch.device":
r"""Get the current available device."""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
else:
device = "cpu"
......@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":
def get_device_count() -> int:
r"""Get the number of available GPU or NPU devices."""
r"""Get the number of available devices."""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_mps_available():
return torch.mps.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
else:
......@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":
def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_xpu_available():
if is_torch_xpu_available():
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
elif is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_mps_available():
return torch.mps.current_allocated_memory(), -1
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
......@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
return torch.float32
def is_gpu_or_npu_available() -> bool:
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
def is_accelerator_available() -> bool:
r"""Check if the accelerator is available."""
return (
is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
)
def is_env_enabled(env_var: str, default: str = "0") -> bool:
......@@ -229,7 +235,7 @@ def skip_check_imports() -> None:
def torch_gc() -> None:
r"""Collect GPU or NPU memory."""
r"""Collect the device memory."""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()
......@@ -280,7 +286,7 @@ def use_ray() -> bool:
def find_available_port() -> int:
"""Find an available port on the local machine."""
r"""Find an available port on the local machine."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
port = sock.getsockname()[1]
......@@ -288,8 +294,8 @@ def find_available_port() -> int:
return port
def fix_proxy(ipv6_enabled: bool) -> None:
"""Fix proxy settings for gradio ui."""
def fix_proxy(ipv6_enabled: bool = False) -> None:
r"""Fix proxy settings for gradio ui."""
os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
if ipv6_enabled:
for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
......
......@@ -411,6 +411,10 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
)
use_muon: bool = field(
default=False,
metadata={"help": "Whether or not to use the Muon optimizer."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
......@@ -431,6 +435,10 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
early_stopping_steps: Optional[int] = field(
default=None,
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
......
......@@ -65,7 +65,13 @@ class BaseModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
add_tokens: Optional[str] = field(
default=None,
metadata={
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
},
)
add_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
......@@ -176,8 +182,11 @@ class BaseModelArguments:
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
if self.add_tokens is not None: # support multiple tokens
self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
if self.add_special_tokens is not None: # support multiple special tokens
self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
@dataclass
......@@ -222,6 +231,10 @@ class ProcessorArguments:
default=False,
metadata={"help": "Use pan and scan to process image for gemma3."},
)
crop_to_patches: bool = field(
default=False,
metadata={"help": "Whether to crop the image to patches for internvl."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
......
......@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
import torch
import transformers
import yaml
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
......@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
if args is not None:
return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return json.loads(Path(sys.argv[1]).absolute().read_text())
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
elif sys.argv[1].endswith(".json"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
else:
return sys.argv[1:]
......@@ -91,6 +96,14 @@ def _set_transformers_logging() -> None:
transformers.utils.logging.enable_explicit_format()
def _set_env_vars() -> None:
if is_torch_npu_available():
# avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
# avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def _verify_model_args(
model_args: "ModelArguments",
data_args: "DataArguments",
......@@ -279,12 +292,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
......@@ -321,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal dataset
if finetuning_args.finetuning_type == "lora":
# https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
training_args.label_names = training_args.label_names or ["labels"]
if (
training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
......@@ -407,6 +429,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
......@@ -428,9 +451,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
_set_transformers_logging()
# Check arguments
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
......
......@@ -34,6 +34,10 @@ class RayArguments:
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
......@@ -55,6 +59,17 @@ class RayArguments:
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
if self.ray_storage_filesystem is not None:
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
)
import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3":
self.ray_storage_filesystem = fs.S3FileSystem()
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
self.ray_storage_filesystem = fs.GcsFileSystem()
@dataclass
......
......@@ -23,7 +23,7 @@ from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules
from .model_utils.visual import COMPOSITE_MODELS, get_forbidden_modules, patch_target_modules
if TYPE_CHECKING:
......@@ -100,7 +100,7 @@ def _setup_freeze_tuning(
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
non_hidden_modules.add(name.split(".")[-2]) # remove weight/bias
trainable_layers = []
for module_name in finetuning_args.freeze_trainable_modules:
......@@ -121,6 +121,10 @@ def _setup_freeze_tuning(
trainable_layers.append(module_name)
model_type = getattr(model.config, "model_type", None)
if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
......@@ -204,7 +208,7 @@ def _setup_lora_tuning(
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
and getattr(model, "quantization_method", None) != QuantizationMethod.BNB
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
......
......@@ -19,7 +19,6 @@ import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForTextToWaveform,
AutoModelForVision2Seq,
......@@ -30,6 +29,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import is_transformers_version_greater_than
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
......@@ -39,6 +39,10 @@ from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
if is_transformers_version_greater_than("4.46.0"):
from transformers import AutoModelForImageTextToText
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
......@@ -97,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, tokenizer, model_args)
except Exception as e:
logger.debug(f"Failed to load processor: {e}.")
logger.info_rank0(f"Failed to load processor: {e}.")
processor = None
# Avoid load tokenizer, see:
......@@ -145,7 +149,10 @@ def load_model(
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
elif (
is_transformers_version_greater_than("4.46.0")
and type(config) in AutoModelForImageTextToText._model_mapping.keys()
): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
......
......@@ -18,7 +18,6 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_availabl
from ...extras import logging
from ...extras.constants import AttentionFunction
from ...extras.misc import check_version
if TYPE_CHECKING:
......@@ -36,8 +35,6 @@ def configure_attn_implementation(
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
check_version("transformers>=4.42.4")
check_version("flash_attn>=2.6.3")
if model_args.flash_attn != AttentionFunction.FA2:
logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = AttentionFunction.FA2
......@@ -72,6 +69,9 @@ def configure_attn_implementation(
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
elif getattr(config, "model_type", None) == "kimi_vl":
setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
......
......@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
outputs = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
return outputs
@staticmethod
@torch.cuda.amp.custom_bwd
......@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
outputs = ctx.forward_function(hidden_states, *ctx.args)
output = outputs[0] if isinstance(outputs, tuple) else outputs
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
......
......@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
......@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than
if not is_transformers_version_greater_than("4.48.0"):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
if TYPE_CHECKING:
......@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
......@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
else:
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
......@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
check_version("transformers>=4.41.2,<4.48.0")
check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment