Commit 2778a3d0 authored by luopl

update to v0.9.1_stable

parent e92143e3
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
 from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_unsupervised_example(
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_unsupervised_example(
......
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 from transformers.utils.versions import require_version
 from typing_extensions import override

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
     from .mm_plugin import BasePlugin

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 @dataclass
@@ -147,7 +147,7 @@ class Template:
                 elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                     token_ids += [tokenizer.eos_token_id]
             else:
-                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

         return token_ids
@@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

     if is_added:
-        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
     else:
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

     if num_added_tokens > 0:
-        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

 def _jinja_escape(content: str) -> str:
@@ -356,22 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
         template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(data_args.template))
+            raise ValueError(f"Template {data_args.template} does not exist.")

+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

     if data_args.tool_format is not None:
-        logger.info("Using tool format: {}.".format(data_args.tool_format))
+        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@@ -389,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
-        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))

         if num_added_tokens > 0:
-            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

-    if template.replace_jinja_template:
+    if tokenizer.chat_template is None or template.replace_jinja_template:
         try:
             tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-        except ValueError:
-            logger.info("Cannot add this chat template to tokenizer.")
+        except ValueError as e:
+            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

     return template
...@@ -692,6 +691,14 @@ _register_template( ...@@ -692,6 +691,14 @@ _register_template(
) )
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
efficient_eos=True,
)
_register_template( _register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
...@@ -755,6 +762,33 @@ _register_template( ...@@ -755,6 +762,33 @@ _register_template(
) )
_register_template(
name="mllama",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
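For illustration, a single turn rendered with the `mllama` template registered above would look roughly like this; the example messages are made up and the bos token value is an assumption, not part of the commit:

# Rough sketch of the prompt string the slots above produce for one exchange.
prompt = (
    "<|begin_of_text|>"  # format_prefix: {"bos_token"}, assumed to be "<|begin_of_text|>" here
    "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n<|image|>Describe this picture.<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)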
_register_template( _register_template(
name="llava", name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
...@@ -904,6 +938,19 @@ _register_template( ...@@ -904,6 +938,19 @@ _register_template(
) )
_register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
_register_template( _register_template(
name="orion", name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
...@@ -935,6 +982,25 @@ _register_template( ...@@ -935,6 +982,25 @@ _register_template(
) )
_register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
_register_template( _register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
......
@@ -177,6 +177,6 @@ TOOLS = {
 def get_tool_utils(name: str) -> "ToolUtils":
     tool_utils = TOOLS.get(name, None)
     if tool_utils is None:
-        raise ValueError("Tool utils `{}` not found.".format(name))
+        raise ValueError(f"Tool utils `{name}` not found.")

     return tool_utils
@@ -87,7 +87,7 @@ class Evaluator:
             token=self.model_args.hf_hub_token,
         )
-        with open(mapping, "r", encoding="utf-8") as f:
+        with open(mapping, encoding="utf-8") as f:
             categorys: Dict[str, Dict[str, str]] = json.load(f)

         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@@ -139,7 +139,7 @@ class Evaluator:
     def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
-                "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+                f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
                 for category_name, category_correct in category_corrects.items()
                 if len(category_correct)
             ]
@@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
 def get_eval_template(name: str) -> "EvalTemplate":
     eval_template = eval_templates.get(name, None)
-    assert eval_template is not None, "Template {} does not exist.".format(name)
+    assert eval_template is not None, f"Template {name} does not exist."
     return eval_template
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional
@@ -47,7 +48,7 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = "<image>"
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}
@@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

-VIDEO_PLACEHOLDER = "<video>"
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"
@@ -107,6 +108,7 @@ VISION_MODELS = set()
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
+    OPENMIND = "om"

 def register_model_group(
...@@ -163,14 +165,17 @@ register_model_group( ...@@ -163,14 +165,17 @@ register_model_group(
"Baichuan2-13B-Base": { "Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
}, },
"Baichuan2-7B-Chat": { "Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
}, },
"Baichuan2-13B-Chat": { "Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
}, },
}, },
template="baichuan2", template="baichuan2",
...@@ -555,10 +560,12 @@ register_model_group( ...@@ -555,10 +560,12 @@ register_model_group(
"Gemma-2-2B-Instruct": { "Gemma-2-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it", DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
}, },
"Gemma-2-9B-Instruct": { "Gemma-2-9B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
}, },
"Gemma-2-27B-Instruct": { "Gemma-2-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it", DownloadSource.DEFAULT: "google/gemma-2-27b-it",
...@@ -578,6 +585,7 @@ register_model_group( ...@@ -578,6 +585,7 @@ register_model_group(
"GLM-4-9B-Chat": { "GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat", DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat", DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
}, },
"GLM-4-9B-1M-Chat": { "GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m", DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
...@@ -588,6 +596,33 @@ register_model_group( ...@@ -588,6 +596,33 @@ register_model_group(
) )
register_model_group(
models={
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
},
"Index-1.9B-Base": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
},
"Index-1.9B-Base-Pure": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
},
"Index-1.9B-Chat-32K": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
},
},
template="index",
)
register_model_group( register_model_group(
models={ models={
"InternLM-7B": { "InternLM-7B": {
...@@ -632,6 +667,7 @@ register_model_group( ...@@ -632,6 +667,7 @@ register_model_group(
"InternLM2.5-1.8B": { "InternLM2.5-1.8B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
}, },
"InternLM2.5-7B": { "InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b", DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
...@@ -640,22 +676,27 @@ register_model_group( ...@@ -640,22 +676,27 @@ register_model_group(
"InternLM2.5-20B": { "InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b", DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
}, },
"InternLM2.5-1.8B-Chat": { "InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
}, },
"InternLM2.5-7B-Chat": { "InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
}, },
"InternLM2.5-7B-1M-Chat": { "InternLM2.5-7B-1M-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
}, },
"InternLM2.5-20B-Chat": { "InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
}, },
}, },
template="intern2", template="intern2",
...@@ -756,6 +797,7 @@ register_model_group( ...@@ -756,6 +797,7 @@ register_model_group(
"Llama-3-8B-Chinese-Chat": { "Llama-3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
}, },
"Llama-3-70B-Chinese-Chat": { "Llama-3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
...@@ -813,6 +855,22 @@ register_model_group( ...@@ -813,6 +855,22 @@ register_model_group(
) )
register_model_group(
models={
"Llama-3.2-11B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
},
"Llama-3.2-90B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
},
},
template="mllama",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"LLaVA-1.5-7B-Chat": { "LLaVA-1.5-7B-Chat": {
...@@ -960,6 +1018,7 @@ register_model_group( ...@@ -960,6 +1018,7 @@ register_model_group(
"MiniCPM3-4B-Chat": { "MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B", DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
}, },
}, },
template="cpm3", template="cpm3",
...@@ -1062,6 +1121,29 @@ register_model_group( ...@@ -1062,6 +1121,29 @@ register_model_group(
) )
register_model_group(
models={
"OpenCoder-1.5B-Base": {
DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Base",
DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Base",
},
"OpenCoder-8B-Base": {
DownloadSource.DEFAULT: "infly/OpenCoder-8B-Base",
DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Base",
},
"OpenCoder-1.5B-Instruct": {
DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Instruct",
},
"OpenCoder-8B-Instruct": {
DownloadSource.DEFAULT: "infly/OpenCoder-8B-Instruct",
DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Instruct",
},
},
template="opencoder",
)
register_model_group( register_model_group(
models={ models={
"Orion-14B-Base": { "Orion-14B-Base": {
...@@ -1141,14 +1223,6 @@ register_model_group( ...@@ -1141,14 +1223,6 @@ register_model_group(
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
}, },
"Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
"Phi-3-14B-8k-Instruct": { "Phi-3-14B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct", DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct", DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
...@@ -1162,6 +1236,33 @@ register_model_group( ...@@ -1162,6 +1236,33 @@ register_model_group(
) )
register_model_group(
models={
"Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
},
template="phi_small",
)
register_model_group(
models={
"Pixtral-12B-Chat": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
},
template="pixtral",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"Qwen-1.8B": { "Qwen-1.8B": {
...@@ -1409,14 +1510,17 @@ register_model_group( ...@@ -1409,14 +1510,17 @@ register_model_group(
"Qwen2-0.5B-Instruct": { "Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
}, },
"Qwen2-1.5B-Instruct": { "Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
}, },
"Qwen2-7B-Instruct": { "Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
}, },
"Qwen2-72B-Instruct": { "Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
...@@ -1649,22 +1753,54 @@ register_model_group( ...@@ -1649,22 +1753,54 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ", DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ", DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ",
}, },
"Qwen2.5-Coder-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B",
},
"Qwen2.5-Coder-1.5B": { "Qwen2.5-Coder-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B", DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B",
}, },
"Qwen2.5-Coder-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B",
},
"Qwen2.5-Coder-7B": { "Qwen2.5-Coder-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B", DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B",
}, },
"Qwen2.5-Coder-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B",
},
"Qwen2.5-Coder-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B",
},
"Qwen2.5-Coder-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B-Instruct",
},
"Qwen2.5-Coder-1.5B-Instruct": { "Qwen2.5-Coder-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
}, },
"Qwen2.5-Coder-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B-Instruct",
},
"Qwen2.5-Coder-7B-Instruct": { "Qwen2.5-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
}, },
"Qwen2.5-Coder-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B-Instruct",
},
"Qwen2.5-Coder-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B-Instruct",
},
"Qwen2.5-Math-1.5B": { "Qwen2.5-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B", DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B",
...@@ -1699,10 +1835,12 @@ register_model_group( ...@@ -1699,10 +1835,12 @@ register_model_group(
"Qwen2-VL-2B-Instruct": { "Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
}, },
"Qwen2-VL-7B-Instruct": { "Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
}, },
"Qwen2-VL-72B-Instruct": { "Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct", DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
...@@ -1801,10 +1939,12 @@ register_model_group( ...@@ -1801,10 +1939,12 @@ register_model_group(
"TeleChat-7B-Chat": { "TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B", DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
}, },
"TeleChat-12B-Chat": { "TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
}, },
"TeleChat-12B-v2-Chat": { "TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
...@@ -2023,6 +2163,7 @@ register_model_group( ...@@ -2023,6 +2163,7 @@ register_model_group(
"Yi-1.5-6B-Chat": { "Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
}, },
"Yi-1.5-9B-Chat": { "Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat", DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
......
@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-VERSION = "0.9.1.dev0"
+VERSION = "0.9.1"

 def print_env() -> None:
@@ -72,4 +72,4 @@ def print_env() -> None:
         except Exception:
             pass

-    print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
+    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
@@ -20,6 +20,7 @@ import os
 import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 from typing import Optional

 from .constants import RUNNING_LOG
@@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
     def __init__(self, output_dir: str) -> None:
         super().__init__()
-        formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+        self._formatter = logging.Formatter(
+            fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         self.setLevel(logging.INFO)
-        self.setFormatter(formatter)
         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
         if os.path.exists(self.running_log):
@@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
         if record.name == "httpx":
             return

-        log_entry = self.format(record)
+        log_entry = self._formatter.format(record)
         self.thread_pool.submit(self._write_log, log_entry)

     def close(self) -> None:
@@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
         return super().close()
+class _Logger(logging.Logger):
+    r"""
+    A logger that supports info_rank0 and warning_once.
+    """
+
+    def info_rank0(self, *args, **kwargs) -> None:
+        self.info(*args, **kwargs)
+
+    def warning_rank0(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+    def warning_once(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
 def _get_default_logging_level() -> "logging._Level":
     r"""
     Returns the default logging level.
@@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
         if env_level_str.upper() in logging._nameToLevel:
             return logging._nameToLevel[env_level_str.upper()]
         else:
-            raise ValueError("Unknown logging level: {}.".format(env_level_str))
+            raise ValueError(f"Unknown logging level: {env_level_str}.")

     return _default_log_level
@@ -84,7 +99,7 @@ def _get_library_name() -> str:
     return __name__.split(".")[0]

-def _get_library_root_logger() -> "logging.Logger":
+def _get_library_root_logger() -> "_Logger":
     return logging.getLogger(_get_library_name())
@@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
     global _default_handler

     with _thread_lock:
-        if _default_handler:
+        if _default_handler:  # already configured
             return

         formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-            datefmt="%m/%d/%Y %H:%M:%S",
+            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         _default_handler = logging.StreamHandler(sys.stdout)
         _default_handler.setFormatter(formatter)
@@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
         library_root_logger.propagate = False

-def get_logger(name: Optional[str] = None) -> "logging.Logger":
+def get_logger(name: Optional[str] = None) -> "_Logger":
     r"""
     Returns a logger with the specified name. It is not supposed to be accessed externally.
     """
@@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
     _configure_library_root_logger()
     return logging.getLogger(name)
+def add_handler(handler: "logging.Handler") -> None:
+    r"""
+    Adds a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    r"""
+    Removes a handler from the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().removeHandler(handler)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.info(*args, **kwargs)
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+@lru_cache(None)
+def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+logging.Logger.warning_rank0 = warning_rank0
+logging.Logger.warning_once = warning_once
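For illustration, a minimal usage sketch of the rank-zero helpers added above; the package path and the simulated LOCAL_RANK are assumptions, not part of this commit:

# Minimal sketch: any logger returned by get_logger() gains the patched helpers.
import os

from llamafactory.extras import logging  # assumed package path

os.environ["LOCAL_RANK"] = "1"  # pretend this process is a non-zero-rank worker

logger = logging.get_logger(__name__)
logger.info("emitted on every rank")
logger.info_rank0("suppressed here because LOCAL_RANK != 0")
logger.warning_once("on rank 0 this would be emitted once per unique message (lru_cache)")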
@@ -20,6 +20,7 @@ import os
 from typing import TYPE_CHECKING, Tuple, Union

 import torch
+import torch.distributed as dist
 import transformers.dynamic_module_utils
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
@@ -32,7 +33,7 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version

-from .logging import get_logger
+from . import logging

 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -48,7 +49,7 @@
     from ..hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 class AverageMeter:
@@ -76,12 +77,12 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
-        require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
+        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@@ -231,18 +232,43 @@ def torch_gc() -> None:
         torch.cuda.empty_cache()

-def try_download_model_from_ms(model_args: "ModelArguments") -> str:
-    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
+    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
         return model_args.model_name_or_path

-    try:
-        from modelscope import snapshot_download
+    if use_modelscope():
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        from modelscope import snapshot_download  # type: ignore

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
-    except ImportError:
-        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=revision,
+            cache_dir=model_args.cache_dir,
+        )
+
+    if use_openmind():
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind.utils.hub import snapshot_download  # type: ignore
+
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=model_args.model_revision,
+            cache_dir=model_args.cache_dir,
+        )

 def use_modelscope() -> bool:
     return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]

+def use_openmind() -> bool:
+    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+
+
+def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
+    r"""
+    calculate effective tokens.
+    """
+    result = effective_token_num * epoch / train_runtime
+    return result / dist.get_world_size() if dist.is_initialized() else result
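For illustration, a rough sketch of the two additions above, `use_openmind` and `cal_effective_tokens`; the package path and the numbers are assumed, not taken from the commit:

import os

from llamafactory.extras.misc import cal_effective_tokens  # assumed package path

# Route model downloads through the Modelers (openMind) hub instead of Hugging Face.
os.environ["USE_OPENMIND_HUB"] = "1"

# 1_000_000 effective tokens over 3 epochs in 600 s on a single process:
# 1_000_000 * 3 / 600 = 5000 effective tokens per second (divided by world size when distributed).
print(cal_effective_tokens(1_000_000, 3, 600.0))  # 5000.0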
@@ -75,8 +75,13 @@ def is_starlette_available():
 @lru_cache
-def is_transformers_version_greater_than_4_43():
-    return _get_package_version("transformers") >= version.parse("4.43.0")
+def is_transformers_version_greater_than(content: str):
+    return _get_package_version("transformers") >= version.parse(content)
+
+
+@lru_cache
+def is_transformers_version_equal_to_4_46():
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")

 def is_uvicorn_available():
......
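For illustration, the generalized version helpers above can guard version-specific code paths roughly like this (import path assumed):

from llamafactory.extras.packages import (  # assumed package path
    is_transformers_version_equal_to_4_46,
    is_transformers_version_greater_than,
)

if is_transformers_version_greater_than("4.45.0"):
    pass  # take the code path that needs transformers >= 4.45.0
if is_transformers_version_equal_to_4_46():
    pass  # apply a workaround specific to transformers 4.46.0 / 4.46.1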
@@ -19,7 +19,7 @@ from typing import Any, Dict, List
 from transformers.trainer import TRAINER_STATE_NAME

-from .logging import get_logger
+from . import logging
 from .packages import is_matplotlib_available
@@ -28,7 +28,7 @@ if is_matplotlib_available():
     import matplotlib.pyplot as plt

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def smooth(scalars: List[float]) -> List[float]:
@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     Plots loss curves and saves the image.
     """
     plt.switch_backend("agg")
-    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)

     for key in keys:
@@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
             metrics.append(data["log_history"][i][key])

         if len(metrics) == 0:
-            logger.warning(f"No metric {key} to plot.")
+            logger.warning_rank0(f"No metric {key} to plot.")
             continue

         plt.figure()
         plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
-        plt.title("training {} of {}".format(key, save_dictionary))
+        plt.title(f"training {key} of {save_dictionary}")
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
......
@@ -41,8 +41,12 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
+    image_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+    )
     cutoff_len: int = field(
-        default=1024,
+        default=2048,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     train_on_prompt: bool = field(
@@ -111,7 +115,13 @@ class DataArguments:
     )
     tokenized_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the tokenized datasets."},
+        metadata={
+            "help": (
+                "Path to save or load the tokenized datasets. "
+                "If tokenized_path not exists, it will save the tokenized datasets. "
+                "If tokenized_path exists, it will load the tokenized datasets."
+            )
+        },
     )

     def __post_init__(self):
@@ -123,6 +133,9 @@ class DataArguments:
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)

+        if self.image_dir is None:
+            self.image_dir = self.dataset_dir
+
         if self.dataset is None and self.val_size > 1e-6:
             raise ValueError("Cannot specify `val_size` if `dataset` is None.")
......
@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    include_effective_tokens_per_second: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to compute effective tokens per second."},
+    )

     def __post_init__(self):
         def split_arg(arg):
......
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import torch import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self from typing_extensions import Self
@@ -57,12 +59,12 @@ class ProcessorArguments:
     """
     image_resolution: int = field(
-        default=512,
-        metadata={"help": "Keeps the height or width of image below this resolution."},
+        default=512 * 512,
+        metadata={"help": "Keeps the number of pixels of image below this resolution."},
     )
     video_resolution: int = field(
-        default=128,
-        metadata={"help": "Keeps the height or width of video below this resolution."},
+        default=128 * 128,
+        metadata={"help": "Keeps the number of pixels of video below this resolution."},
     )
     video_fps: float = field(
         default=2.0,
@@ -125,7 +127,7 @@ class VllmArguments:
     """
     vllm_maxlen: int = field(
-        default=2048,
+        default=4096,
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
@@ -140,6 +142,10 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
    )
+    vllm_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
+    )
@dataclass @dataclass
...@@ -267,6 +273,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, ...@@ -267,6 +273,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=None, default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."}, metadata={"help": "Auth token to log in with ModelScope Hub."},
) )
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field( print_param_status: bool = field(
default=False, default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
@@ -308,6 +318,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")

+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
@classmethod @classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self": def copyfrom(cls, source: "Self", **kwargs) -> "Self":
init_args, lazy_args = {}, {} init_args, lazy_args = {}, {}
......
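For illustration, the new `vllm_config` field accepts a JSON string that `__post_init__` converts to a dict; the package path and the engine options below are made-up examples, not from this commit:

from llamafactory.hparams import ModelArguments  # assumed package path

model_args = ModelArguments(
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    vllm_config='{"max_num_seqs": 64, "enforce_eager": true}',  # example engine options
)
print(model_args.vllm_config)  # parsed into a dict: {'max_num_seqs': 64, 'enforce_eager': True}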
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple
@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 check_dependencies()
@@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     if unknown_args:
         print(parser.format_help())
-        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
-        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

     return (*parsed_args,)

-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
...@@ -104,7 +103,7 @@ def _verify_model_args( ...@@ -104,7 +103,7 @@ def _verify_model_args(
raise ValueError("Quantized model only accepts a single adapter. Merge them first.") raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer: if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
...@@ -123,7 +122,7 @@ def _check_extra_dependencies( ...@@ -123,7 +122,7 @@ def _check_extra_dependencies(
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<=0.6.2", "To fix: pip install vllm>=0.4.3,<=0.6.2") require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
if finetuning_args.use_galore: if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch") require_version("galore_torch", "To fix: pip install galore_torch")
...@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing: if data_args.neat_packing and not data_args.packing:
logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.") logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args) _verify_model_args(model_args, data_args, finetuning_args)
...@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and model_args.resize_vocab and model_args.resize_vocab
and finetuning_args.additional_target is None and finetuning_args.additional_target is None
): ):
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") logger.warning_rank0(
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
)
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.") logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.") logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.") logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None: if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.") logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments # Post-process training arguments
if ( if (
...@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
and training_args.ddp_find_unused_parameters is None and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
): ):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None: if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.") logger.warning_rank0("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None training_args.resume_from_checkpoint = None
else: else:
can_resume_from_checkpoint = True can_resume_from_checkpoint = True
...@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if last_checkpoint is not None: if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint)) logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if ( if (
finetuning_args.stage in ["rm", "ppo"] finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora" and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None and training_args.resume_from_checkpoint is not None
): ):
logger.warning( logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint training_args.resume_from_checkpoint
) )
......
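The hunks above migrate plain logger.warning/info calls to the rank-aware warning_rank0/info_rank0 helpers from extras.logging, whose body is outside this diff. A minimal sketch of what such a rank-0-only wrapper could look like (the LOCAL_RANK check and the helper name here are assumptions, not the project's actual implementation):

# Illustrative sketch only: emit a warning on the main process and stay silent
# on other ranks. torchrun exports LOCAL_RANK per process; single-process runs
# fall back to rank 0.
import logging
import os


def warning_rank0(logger: logging.Logger, msg: str) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        logger.warning(msg)


logger = logging.getLogger(__name__)
warning_rank0(logger, "packing should be enabled when neat_packing is set.")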
...@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model ...@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
...@@ -33,7 +33,7 @@ if TYPE_CHECKING: ...@@ -33,7 +33,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _setup_full_tuning( def _setup_full_tuning(
...@@ -45,7 +45,7 @@ def _setup_full_tuning( ...@@ -45,7 +45,7 @@ def _setup_full_tuning(
if not is_trainable: if not is_trainable:
return return
logger.info("Fine-tuning method: Full") logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args) forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules): if not any(forbidden_module in name for forbidden_module in forbidden_modules):
...@@ -64,7 +64,7 @@ def _setup_freeze_tuning( ...@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
if not is_trainable: if not is_trainable:
return return
logger.info("Fine-tuning method: Freeze") logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config") config = getattr(model.config, "text_config")
else: else:
...@@ -133,7 +133,7 @@ def _setup_freeze_tuning( ...@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
else: else:
param.requires_grad_(False) param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning( def _setup_lora_tuning(
...@@ -145,7 +145,7 @@ def _setup_lora_tuning( ...@@ -145,7 +145,7 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
) -> "PeftModel": ) -> "PeftModel":
if is_trainable: if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None adapter_to_resume = None
...@@ -182,7 +182,7 @@ def _setup_lora_tuning( ...@@ -182,7 +182,7 @@ def _setup_lora_tuning(
model = model.merge_and_unload() model = model.merge_and_unload()
if len(adapter_to_merge) > 0: if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth: if model_args.use_unsloth:
...@@ -190,7 +190,7 @@ def _setup_lora_tuning( ...@@ -190,7 +190,7 @@ def _setup_lora_tuning(
else: else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
...@@ -219,7 +219,7 @@ def _setup_lora_tuning( ...@@ -219,7 +219,7 @@ def _setup_lora_tuning(
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = { peft_kwargs = {
"r": finetuning_args.lora_rank, "r": finetuning_args.lora_rank,
...@@ -236,11 +236,11 @@ def _setup_lora_tuning( ...@@ -236,11 +236,11 @@ def _setup_lora_tuning(
else: else:
if finetuning_args.pissa_init: if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1: if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.") logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa" peft_kwargs["init_lora_weights"] = "pissa"
else: else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
...@@ -284,11 +284,11 @@ def init_adapter( ...@@ -284,11 +284,11 @@ def init_adapter(
if not is_trainable: if not is_trainable:
pass pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam: elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else: else:
logger.info("Upcasting trainable params to float32.") logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full": if finetuning_args.finetuning_type == "full":
...@@ -300,6 +300,6 @@ def init_adapter( ...@@ -300,6 +300,6 @@ def init_adapter(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
) )
else: else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model return model
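find_all_linear_modules and find_expanded_modules are imported above but defined outside this diff. A rough sketch, under assumptions, of how `lora_target: all` could be resolved into concrete target module names for LoraConfig (the lm_head exclusion is an assumption):

# Illustrative sketch: collect the leaf names of every torch.nn.Linear module
# so they can be handed to LoraConfig(target_modules=...); output heads such
# as lm_head are skipped here.
from typing import List

import torch


def find_linear_module_names(model: torch.nn.Module) -> List[str]:
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
            module_names.add(name.split(".")[-1])
    return list(module_names)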
...@@ -18,8 +18,8 @@ import torch ...@@ -18,8 +18,8 @@ import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass from .model_utils.misc import register_autoclass
...@@ -35,7 +35,7 @@ if TYPE_CHECKING: ...@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class TokenizerModule(TypedDict): class TokenizerModule(TypedDict):
...@@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: ...@@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
skip_check_imports() skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args) model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return { return {
"trust_remote_code": True, "trust_remote_code": True,
"cache_dir": model_args.cache_dir, "cache_dir": model_args.cache_dir,
...@@ -90,17 +90,17 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": ...@@ -90,17 +90,17 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
dict(additional_special_tokens=model_args.new_special_tokens), dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False, replace_additional_special_tokens=False,
) )
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab: if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.") logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer) patch_tokenizer(tokenizer)
try: try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args) patch_processor(processor, config, tokenizer, model_args)
except Exception as e: except Exception as e:
logger.warning("Processor was not found: {}.".format(e)) logger.debug(f"Processor was not found: {e}.")
processor = None processor = None
# Avoid load tokenizer, see: # Avoid load tokenizer, see:
...@@ -153,8 +153,9 @@ def load_model( ...@@ -153,8 +153,9 @@ def load_model(
load_class = AutoModelForVision2Seq load_class = AutoModelForVision2Seq
else: else:
load_class = AutoModelForCausalLM load_class = AutoModelForCausalLM
if model_args.train_from_scratch: if model_args.train_from_scratch:
model = load_class.from_config(config) model = load_class.from_config(config, trust_remote_code=True)
else: else:
model = load_class.from_pretrained(**init_kwargs) model = load_class.from_pretrained(**init_kwargs)
...@@ -179,7 +180,7 @@ def load_model( ...@@ -179,7 +180,7 @@ def load_model(
vhead_params = load_valuehead_params(vhead_path, model_args) vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None: if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False) model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable: if not is_trainable:
model.requires_grad_(False) model.requires_grad_(False)
...@@ -197,9 +198,9 @@ def load_model( ...@@ -197,9 +198,9 @@ def load_model(
trainable_params, all_param, 100 * trainable_params / all_param trainable_params, all_param, 100 * trainable_params / all_param
) )
else: else:
param_stats = "all params: {:,}".format(all_param) param_stats = f"all params: {all_param:,}"
logger.info(param_stats) logger.info_rank0(param_stats)
if model_args.print_param_status: if model_args.print_param_status:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
......
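count_parameters is imported above from extras.misc but not shown; a minimal sketch of the trainable/total counts that feed the param_stats message (the real helper may treat quantized or ZeRO-3-partitioned weights differently):

# Illustrative sketch: count trainable and total parameters the way the
# param_stats line above consumes them.
from typing import Tuple

import torch


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param


model = torch.nn.Linear(4, 2)
trainable_params, all_param = count_parameters(model)
print(f"trainable params: {trainable_params:,} || all params: {all_param:,}")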
...@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING ...@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -26,7 +26,7 @@ if TYPE_CHECKING: ...@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_attn_implementation( def configure_attn_implementation(
...@@ -38,13 +38,15 @@ def configure_attn_implementation( ...@@ -38,13 +38,15 @@ def configure_attn_implementation(
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
if model_args.flash_attn != "fa2": if model_args.flash_attn != "fa2":
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2" model_args.flash_attn = "fa2"
else: else:
logger.warning("FlashAttention-2 is not installed, use eager attention.") logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled" model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.") logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == "auto": if model_args.flash_attn == "auto":
return return
...@@ -54,18 +56,18 @@ def configure_attn_implementation( ...@@ -54,18 +56,18 @@ def configure_attn_implementation(
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
if not is_torch_sdpa_available(): if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.") logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return return
requested_attn_implementation = "sdpa" requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2": elif model_args.flash_attn == "fa2":
if not is_flash_attn_2_available(): if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.") logger.warning_rank0("FlashAttention-2 is not installed.")
return return
requested_attn_implementation = "flash_attention_2" requested_attn_implementation = "flash_attention_2"
else: else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation) setattr(config, "attn_implementation", requested_attn_implementation)
...@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None: ...@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
attn_implementation = getattr(config, "_attn_implementation", None) attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2": if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.") logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa": elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.") logger.info_rank0("Using torch SDPA for faster training and inference.")
else: else:
logger.info("Using vanilla attention implementation.") logger.info_rank0("Using vanilla attention implementation.")
...@@ -19,14 +19,14 @@ ...@@ -19,14 +19,14 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from functools import partial, wraps from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -35,7 +35,7 @@ if TYPE_CHECKING: ...@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable: def get_unsloth_gradient_checkpointing_func() -> Callable:
...@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable ...@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
Only applies gradient checkpointing to trainable layers. Only applies gradient checkpointing to trainable layers.
""" """
@wraps(gradient_checkpointing_func) @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs): def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__ module: "torch.nn.Module" = func.__self__
...@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable ...@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
return gradient_checkpointing_func(func, *args, **kwargs) return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func return custom_gradient_checkpointing_func
...@@ -111,7 +108,7 @@ def _gradient_checkpointing_enable( ...@@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing: if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None: if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True} gradient_checkpointing_kwargs = {"use_reentrant": True}
...@@ -125,7 +122,7 @@ def _gradient_checkpointing_enable( ...@@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True)) self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads() self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
...@@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum ...@@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
(3) add the upcasting of the lm_head in fp32 (3) add the upcasting of the lm_head in fp32
""" """
if model_args.upcast_layernorm: if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.") logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing: if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False): if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.") logger.warning_rank0("Current model does not support gradient checkpointing.")
else: else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet) # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339 # According to: https://github.com/huggingface/transformers/issues/28339
...@@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum ...@@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.") logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output: if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings() output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.") logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook) output_layer.register_forward_hook(_fp32_forward_post_hook)
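The @wraps change above copies the bound method's __self__ onto the wrapper, replacing the removed hasattr(...) patch; a self-contained sketch of the same pattern with a toy class (illustration only):

# Illustrative sketch: functools.wraps normally copies only __module__,
# __name__, __qualname__, __doc__ and __dict__; adding "__self__" keeps the
# wrapper looking like the original bound method, which is what the deleted
# manual assignment used to do.
from functools import WRAPPER_ASSIGNMENTS, wraps


class Layer:
    def forward(self, x):
        return x * 2


layer = Layer()
bound = layer.forward  # bound method, carries __self__


@wraps(bound, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def wrapped(*args, **kwargs):
    return bound(*args, **kwargs)


assert wrapped.__self__ is layer
print(wrapped(3))  # 6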
...@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING ...@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
...@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken ...@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")