Commit 2778a3d0 authored by luopl's avatar luopl
Browse files

updata to v0.9.1_stable

parent e92143e3
......@@ -15,7 +15,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen
......@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
......@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_unsupervised_example(
......
......@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras.logging import get_logger
from ..extras import logging
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
......@@ -32,7 +32,7 @@ if TYPE_CHECKING:
from .mm_plugin import BasePlugin
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
@dataclass
......@@ -147,7 +147,7 @@ class Template:
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids
......@@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token))
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else:
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
......@@ -356,22 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
else:
template = TEMPLATES.get(data_args.template, None)
if template is None:
raise ValueError("Template {} does not exist.".format(data_args.template))
raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info("Using tool format: {}.".format(data_args.tool_format))
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
......@@ -389,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
if template.replace_jinja_template:
if tokenizer.chat_template is None or template.replace_jinja_template:
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
return template
......@@ -692,6 +691,14 @@ _register_template(
)
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
efficient_eos=True,
)
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
......@@ -755,6 +762,33 @@ _register_template(
)
_register_template(
name="mllama",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
......@@ -904,6 +938,19 @@ _register_template(
)
_register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
......@@ -935,6 +982,25 @@ _register_template(
)
_register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
......
......@@ -177,6 +177,6 @@ TOOLS = {
def get_tool_utils(name: str) -> "ToolUtils":
tool_utils = TOOLS.get(name, None)
if tool_utils is None:
raise ValueError("Tool utils `{}` not found.".format(name))
raise ValueError(f"Tool utils `{name}` not found.")
return tool_utils
......@@ -87,7 +87,7 @@ class Evaluator:
token=self.model_args.hf_hub_token,
)
with open(mapping, "r", encoding="utf-8") as f:
with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
......@@ -139,7 +139,7 @@ class Evaluator:
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join(
[
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
for category_name, category_correct in category_corrects.items()
if len(category_correct)
]
......
......@@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
def get_eval_template(name: str) -> "EvalTemplate":
eval_template = eval_templates.get(name, None)
assert eval_template is not None, "Template {} does not exist.".format(name)
assert eval_template is not None, f"Template {name} does not exist."
return eval_template
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
......@@ -47,7 +48,7 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>"
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"}
......@@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = "<video>"
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin"
......@@ -107,6 +108,7 @@ VISION_MODELS = set()
class DownloadSource(str, Enum):
DEFAULT = "hf"
MODELSCOPE = "ms"
OPENMIND = "om"
def register_model_group(
......@@ -163,14 +165,17 @@ register_model_group(
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
},
},
template="baichuan2",
......@@ -555,10 +560,12 @@ register_model_group(
"Gemma-2-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
},
"Gemma-2-9B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
},
"Gemma-2-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
......@@ -578,6 +585,7 @@ register_model_group(
"GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
},
"GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
......@@ -588,6 +596,33 @@ register_model_group(
)
register_model_group(
models={
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
},
"Index-1.9B-Base": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
},
"Index-1.9B-Base-Pure": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
},
"Index-1.9B-Chat-32K": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
},
},
template="index",
)
register_model_group(
models={
"InternLM-7B": {
......@@ -632,6 +667,7 @@ register_model_group(
"InternLM2.5-1.8B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
},
"InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
......@@ -640,22 +676,27 @@ register_model_group(
"InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
},
"InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
},
"InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
},
"InternLM2.5-7B-1M-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
},
"InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
},
},
template="intern2",
......@@ -756,6 +797,7 @@ register_model_group(
"Llama-3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
},
"Llama-3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
......@@ -813,6 +855,22 @@ register_model_group(
)
register_model_group(
models={
"Llama-3.2-11B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
},
"Llama-3.2-90B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
},
},
template="mllama",
vision=True,
)
register_model_group(
models={
"LLaVA-1.5-7B-Chat": {
......@@ -960,6 +1018,7 @@ register_model_group(
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
},
},
template="cpm3",
......@@ -1062,6 +1121,29 @@ register_model_group(
)
register_model_group(
models={
"OpenCoder-1.5B-Base": {
DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Base",
DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Base",
},
"OpenCoder-8B-Base": {
DownloadSource.DEFAULT: "infly/OpenCoder-8B-Base",
DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Base",
},
"OpenCoder-1.5B-Instruct": {
DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Instruct",
},
"OpenCoder-8B-Instruct": {
DownloadSource.DEFAULT: "infly/OpenCoder-8B-Instruct",
DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Instruct",
},
},
template="opencoder",
)
register_model_group(
models={
"Orion-14B-Base": {
......@@ -1141,14 +1223,6 @@ register_model_group(
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
},
"Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
"Phi-3-14B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
......@@ -1162,6 +1236,33 @@ register_model_group(
)
register_model_group(
models={
"Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
},
template="phi_small",
)
register_model_group(
models={
"Pixtral-12B-Chat": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
},
template="pixtral",
vision=True,
)
register_model_group(
models={
"Qwen-1.8B": {
......@@ -1409,14 +1510,17 @@ register_model_group(
"Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
},
"Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
},
"Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
},
"Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
......@@ -1649,22 +1753,54 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ",
},
"Qwen2.5-Coder-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B",
},
"Qwen2.5-Coder-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B",
},
"Qwen2.5-Coder-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B",
},
"Qwen2.5-Coder-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B",
},
"Qwen2.5-Coder-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B",
},
"Qwen2.5-Coder-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B",
},
"Qwen2.5-Coder-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B-Instruct",
},
"Qwen2.5-Coder-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Coder-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B-Instruct",
},
"Qwen2.5-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Coder-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B-Instruct",
},
"Qwen2.5-Coder-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B-Instruct",
},
"Qwen2.5-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B",
......@@ -1699,10 +1835,12 @@ register_model_group(
"Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
},
"Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
},
"Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
......@@ -1801,10 +1939,12 @@ register_model_group(
"TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
},
"TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
......@@ -2023,6 +2163,7 @@ register_model_group(
"Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
},
"Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
......
......@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.1.dev0"
VERSION = "0.9.1"
def print_env() -> None:
......@@ -72,4 +72,4 @@ def print_env() -> None:
except Exception:
pass
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
......@@ -20,6 +20,7 @@ import os
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Optional
from .constants import RUNNING_LOG
......@@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
def __init__(self, output_dir: str) -> None:
super().__init__()
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
self._formatter = logging.Formatter(
fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
self.setLevel(logging.INFO)
self.setFormatter(formatter)
os.makedirs(output_dir, exist_ok=True)
self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log):
......@@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
if record.name == "httpx":
return
log_entry = self.format(record)
log_entry = self._formatter.format(record)
self.thread_pool.submit(self._write_log, log_entry)
def close(self) -> None:
......@@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
return super().close()
class _Logger(logging.Logger):
r"""
A logger that supports info_rank0 and warning_once.
"""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
......@@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError("Unknown logging level: {}.".format(env_level_str))
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level
......@@ -84,7 +99,7 @@ def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> "logging.Logger":
def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name())
......@@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
global _default_handler
with _thread_lock:
if _default_handler:
if _default_handler: # already configured
return
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
......@@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "logging.Logger":
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""
Returns a logger with the specified name. It it not supposed to be accessed externally.
"""
......@@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
_configure_library_root_logger()
return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
Removes a handler to the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
......@@ -20,6 +20,7 @@ import os
from typing import TYPE_CHECKING, Tuple, Union
import torch
import torch.distributed as dist
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
......@@ -32,7 +33,7 @@ from transformers.utils import (
)
from transformers.utils.versions import require_version
from .logging import get_logger
from . import logging
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
......@@ -48,7 +49,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class AverageMeter:
......@@ -76,12 +77,12 @@ def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
......@@ -231,18 +232,43 @@ def torch_gc() -> None:
torch.cuda.empty_cache()
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
return model_args.model_name_or_path
try:
from modelscope import snapshot_download
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
return snapshot_download(
model_args.model_name_or_path,
revision=revision,
cache_dir=model_args.cache_dir,
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
model_args.model_name_or_path,
revision=model_args.model_revision,
cache_dir=model_args.cache_dir,
)
def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
r"""
calculate effective tokens.
"""
result = effective_token_num * epoch / train_runtime
return result / dist.get_world_size() if dist.is_initialized() else result
......@@ -75,8 +75,13 @@ def is_starlette_available():
@lru_cache
def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0")
def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
@lru_cache
def is_transformers_version_equal_to_4_46():
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
def is_uvicorn_available():
......
......@@ -19,7 +19,7 @@ from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger
from . import logging
from .packages import is_matplotlib_available
......@@ -28,7 +28,7 @@ if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
......@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
Plots loss curves and saves the image.
"""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
for key in keys:
......@@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
metrics.append(data["log_history"][i][key])
if len(metrics) == 0:
logger.warning(f"No metric {key} to plot.")
logger.warning_rank0(f"No metric {key} to plot.")
continue
plt.figure()
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title("training {} of {}".format(key, save_dictionary))
plt.title(f"training {key} of {save_dictionary}")
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
......
......@@ -41,8 +41,12 @@ class DataArguments:
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
image_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
)
cutoff_len: int = field(
default=1024,
default=2048,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
train_on_prompt: bool = field(
......@@ -111,7 +115,13 @@ class DataArguments:
)
tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the tokenized datasets."},
metadata={
"help": (
"Path to save or load the tokenized datasets. "
"If tokenized_path not exists, it will save the tokenized datasets. "
"If tokenized_path exists, it will load the tokenized datasets."
)
},
)
def __post_init__(self):
......@@ -123,6 +133,9 @@ class DataArguments:
self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset)
if self.image_dir is None:
self.image_dir = self.dataset_dir
if self.dataset is None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
......
......@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
)
include_effective_tokens_per_second: bool = field(
default=False,
metadata={"help": "Whether or not to compute effective tokens per second."},
)
def __post_init__(self):
def split_arg(arg):
......
......@@ -15,10 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
......@@ -57,12 +59,12 @@ class ProcessorArguments:
"""
image_resolution: int = field(
default=512,
metadata={"help": "Keeps the height or width of image below this resolution."},
default=512 * 512,
metadata={"help": "Keeps the number of pixels of image below this resolution."},
)
video_resolution: int = field(
default=128,
metadata={"help": "Keeps the height or width of video below this resolution."},
default=128 * 128,
metadata={"help": "Keeps the number of pixels of video below this resolution."},
)
video_fps: float = field(
default=2.0,
......@@ -125,7 +127,7 @@ class VllmArguments:
"""
vllm_maxlen: int = field(
default=2048,
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
......@@ -140,6 +142,10 @@ class VllmArguments:
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
vllm_config: Optional[Union[dict, str]] = field(
default=None,
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
)
@dataclass
......@@ -267,6 +273,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
......@@ -308,6 +318,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
@classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
init_args, lazy_args = {}, {}
......
......@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from typing import Any, Dict, Optional, Tuple
......@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
......@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
from .model_args import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
check_dependencies()
......@@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if unknown_args:
print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,)
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
transformers.utils.logging.set_verbosity(log_level)
def _set_transformers_logging() -> None:
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
......@@ -104,7 +103,7 @@ def _verify_model_args(
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
......@@ -123,7 +122,7 @@ def _check_extra_dependencies(
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<=0.6.2", "To fix: pip install vllm>=0.4.3,<=0.6.2")
require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
......@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing:
logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args)
......@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
logger.warning_rank0(
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
)
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
if (
......@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.")
logger.warning_rank0("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
......@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning(
logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
......
......@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
......@@ -33,7 +33,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _setup_full_tuning(
......@@ -45,7 +45,7 @@ def _setup_full_tuning(
if not is_trainable:
return
logger.info("Fine-tuning method: Full")
logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
......@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
if not is_trainable:
return
logger.info("Fine-tuning method: Freeze")
logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
......@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
else:
param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning(
......@@ -145,7 +145,7 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
......@@ -182,7 +182,7 @@ def _setup_lora_tuning(
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
......@@ -190,7 +190,7 @@ def _setup_lora_tuning(
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
......@@ -219,7 +219,7 @@ def _setup_lora_tuning(
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
......@@ -236,11 +236,11 @@ def _setup_lora_tuning(
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.")
logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
......@@ -284,11 +284,11 @@ def init_adapter(
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info("Upcasting trainable params to float32.")
logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
......@@ -300,6 +300,6 @@ def init_adapter(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model
......@@ -18,8 +18,8 @@ import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
......@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class TokenizerModule(TypedDict):
......@@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
Note: including inplace operation of model_args.
"""
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args)
model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
......@@ -90,17 +90,17 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
except Exception as e:
logger.warning("Processor was not found: {}.".format(e))
logger.debug(f"Processor was not found: {e}.")
processor = None
# Avoid load tokenizer, see:
......@@ -153,8 +153,9 @@ def load_model(
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
model = load_class.from_config(config, trust_remote_code=True)
else:
model = load_class.from_pretrained(**init_kwargs)
......@@ -179,7 +180,7 @@ def load_model(
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable:
model.requires_grad_(False)
......@@ -197,9 +198,9 @@ def load_model(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:,}".format(all_param)
param_stats = f"all params: {all_param:,}"
logger.info(param_stats)
logger.info_rank0(param_stats)
if model_args.print_param_status:
for name, param in model.named_parameters():
......
......@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
......@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def configure_attn_implementation(
......@@ -38,13 +38,15 @@ def configure_attn_implementation(
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:
logger.warning("FlashAttention-2 is not installed, use eager attention.")
logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == "auto":
return
......@@ -54,18 +56,18 @@ def configure_attn_implementation(
elif model_args.flash_attn == "sdpa":
if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.")
logger.warning_rank0("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
......@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
logger.info_rank0("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla attention implementation.")
logger.info_rank0("Using vanilla attention implementation.")
......@@ -19,14 +19,14 @@
# limitations under the License.
import inspect
from functools import partial, wraps
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch
from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
......@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
......@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
......@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func
......@@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
......@@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
......@@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
(3) add the upcasting of the lm_head in fp32
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
logger.warning_rank0("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
......@@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.")
logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)
......@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
......@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment