OpenDAS / LLaMA-Factory · Commits

Commit 2778a3d0, authored Jan 16, 2025 by luopl

    update to v0.9.1_stable

Parent: e92143e3
Changes: 172

Showing 20 changed files with 476 additions and 151 deletions (+476 / -151)
Changed files:

  src/llamafactory/data/processors/unsupervised.py      +5    -3
  src/llamafactory/data/template.py                      +84   -18
  src/llamafactory/data/tool_utils.py                    +1    -1
  src/llamafactory/eval/evaluator.py                     +2    -2
  src/llamafactory/eval/template.py                      +1    -1
  src/llamafactory/extras/constants.py                   +151  -10
  src/llamafactory/extras/env.py                         +2    -2
  src/llamafactory/extras/logging.py                     +63   -11
  src/llamafactory/extras/misc.py                        +40   -14
  src/llamafactory/extras/packages.py                    +7    -2
  src/llamafactory/extras/ploting.py                     +5    -5
  src/llamafactory/hparams/data_args.py                  +15   -2
  src/llamafactory/hparams/finetuning_args.py            +4    -0
  src/llamafactory/hparams/model_args.py                 +18   -5
  src/llamafactory/hparams/parser.py                     +24   -21
  src/llamafactory/model/adapter.py                      +16   -16
  src/llamafactory/model/loader.py                       +12   -11
  src/llamafactory/model/model_utils/attention.py        +13   -11
  src/llamafactory/model/model_utils/checkpointing.py    +10   -13
  src/llamafactory/model/model_utils/embedding.py        +3    -3
src/llamafactory/data/processors/unsupervised.py

@@ -15,7 +15,7 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
 from .processor_utils import infer_seqlen

@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_unsupervised_example(

@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_unsupervised_example(
src/llamafactory/data/template.py

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 from transformers.utils.versions import require_version
 from typing_extensions import override

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin

@@ -32,7 +32,7 @@ if TYPE_CHECKING:
     from .mm_plugin import BasePlugin


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @dataclass

@@ -147,7 +147,7 @@ class Template:
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
-                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

        return token_ids

@@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
     if is_added:
-        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
     else:
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

     if num_added_tokens > 0:
-        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")


 def _jinja_escape(content: str) -> str:

@@ -356,22 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
         template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(data_args.template))
+            raise ValueError(f"Template {data_args.template} does not exist.")

+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

     if data_args.tool_format is not None:
-        logger.info("Using tool format: {}.".format(data_args.tool_format))
+        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

@@ -389,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
-        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
         if num_added_tokens > 0:
-            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

-    if template.replace_jinja_template:
+    if tokenizer.chat_template is None or template.replace_jinja_template:
         try:
             tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-        except ValueError:
-            logger.info("Cannot add this chat template to tokenizer.")
+        except ValueError as e:
+            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

     return template

@@ -692,6 +691,14 @@ _register_template(
 )


+_register_template(
+    name="index",
+    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
+    format_system=StringFormatter(slots=["<unk>{{content}}"]),
+    efficient_eos=True,
+)
+
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),

@@ -755,6 +762,33 @@ _register_template(
 )


+_register_template(
+    name="mllama",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
+)
+
+
 _register_template(
     name="llava",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),

@@ -904,6 +938,19 @@ _register_template(
 )


+_register_template(
+    name="opencoder",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are OpenCoder, created by OpenCoder Team.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+)
+
+
 _register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),

@@ -935,6 +982,25 @@ _register_template(
 )


+_register_template(
+    name="phi_small",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
+    stop_words=["<|end|>"],
+    replace_eos=True,
+)
+
+
+_register_template(
+    name="pixtral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
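The most consequential change in this file is the version gate: instead of hard-coding the multimodal template names ("llava", "paligemma", "qwen2_vl"), the new code checks whether the template carries a non-base mm_plugin. A minimal standalone sketch of that idea (the class names below are illustrative, not the project's actual plugin classes):

    # Hedged sketch: gate on the plugin type attached to a template rather than
    # on a hard-coded list of template names.
    class BasePlugin:  # text-only templates keep the base plugin
        pass

    class HypotheticalMMPlugin(BasePlugin):  # stands in for a multimodal plugin
        pass

    def needs_recent_transformers(mm_plugin: BasePlugin) -> bool:
        # Any plugin other than the base one implies a multimodal template.
        return mm_plugin.__class__.__name__ != "BasePlugin"

    assert not needs_recent_transformers(BasePlugin())
    assert needs_recent_transformers(HypotheticalMMPlugin())

New multimodal templates such as mllama and pixtral therefore pick up the transformers>=4.45.0 requirement automatically as soon as they register an mm_plugin.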
src/llamafactory/data/tool_utils.py

@@ -177,6 +177,6 @@ TOOLS = {
 def get_tool_utils(name: str) -> "ToolUtils":
     tool_utils = TOOLS.get(name, None)
     if tool_utils is None:
-        raise ValueError("Tool utils `{}` not found.".format(name))
+        raise ValueError(f"Tool utils `{name}` not found.")

     return tool_utils
src/llamafactory/eval/evaluator.py

@@ -87,7 +87,7 @@ class Evaluator:
             token=self.model_args.hf_hub_token,
         )

-        with open(mapping, "r", encoding="utf-8") as f:
+        with open(mapping, encoding="utf-8") as f:
             categorys: Dict[str, Dict[str, str]] = json.load(f)

         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}

@@ -139,7 +139,7 @@ class Evaluator:
     def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
-                "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+                f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
                 for category_name, category_correct in category_corrects.items()
                 if len(category_correct)
             ]
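The _save_results change only swaps str.format for an f-string; the alignment and precision specifiers are unchanged. A quick standalone check (values chosen purely for illustration) that the two spellings produce identical output:

    import numpy as np

    category_name = "STEM"
    category_correct = np.array([True, False, True, True])

    old = "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
    new = f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
    assert old == new == "           STEM: 75.00"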
src/llamafactory/eval/template.py

@@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->
 def get_eval_template(name: str) -> "EvalTemplate":
     eval_template = eval_templates.get(name, None)
-    assert eval_template is not None, "Template {} does not exist.".format(name)
+    assert eval_template is not None, f"Template {name} does not exist."
     return eval_template
src/llamafactory/extras/constants.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional

@@ -47,7 +48,7 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = "<image>"
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -95,7 +96,7 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

-VIDEO_PLACEHOLDER = "<video>"
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

@@ -107,6 +108,7 @@ VISION_MODELS = set()
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
+    OPENMIND = "om"


 def register_model_group(

@@ -163,14 +165,17 @@ register_model_group(
         "Baichuan2-13B-Base": {
             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
+            DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
         },
         "Baichuan2-7B-Chat": {
             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
+            DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
        },
         "Baichuan2-13B-Chat": {
             DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
             DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
+            DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
         },
     },
     template="baichuan2",

@@ -555,10 +560,12 @@ register_model_group(
         "Gemma-2-2B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-2-2b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
+            DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
         },
         "Gemma-2-9B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
+            DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
         },
         "Gemma-2-27B-Instruct": {
             DownloadSource.DEFAULT: "google/gemma-2-27b-it",

@@ -578,6 +585,7 @@ register_model_group(
         "GLM-4-9B-Chat": {
             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
+            DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
         },
         "GLM-4-9B-1M-Chat": {
             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",

@@ -588,6 +596,33 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Index-1.9B-Chat": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
+        },
+        "Index-1.9B-Character-Chat": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
+        },
+        "Index-1.9B-Base": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
+        },
+        "Index-1.9B-Base-Pure": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
+        },
+        "Index-1.9B-Chat-32K": {
+            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
+            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
+        },
+    },
+    template="index",
+)
+
+
 register_model_group(
     models={
         "InternLM-7B": {

@@ -632,6 +667,7 @@ register_model_group(
         "InternLM2.5-1.8B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
         },
         "InternLM2.5-7B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b",

@@ -640,22 +676,27 @@ register_model_group(
         "InternLM2.5-20B": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
         },
         "InternLM2.5-1.8B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
         },
         "InternLM2.5-7B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
         },
         "InternLM2.5-7B-1M-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
         },
         "InternLM2.5-20B-Chat": {
             DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
             DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
+            DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
         },
     },
     template="intern2",

@@ -756,6 +797,7 @@ register_model_group(
         "Llama-3-8B-Chinese-Chat": {
             DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
             DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
+            DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
         },
         "Llama-3-70B-Chinese-Chat": {
             DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",

@@ -813,6 +855,22 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Llama-3.2-11B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
+        },
+        "Llama-3.2-90B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
+        },
+    },
+    template="mllama",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {

@@ -960,6 +1018,7 @@ register_model_group(
         "MiniCPM3-4B-Chat": {
             DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
+            DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
         },
     },
     template="cpm3",

@@ -1062,6 +1121,29 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "OpenCoder-1.5B-Base": {
+            DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Base",
+            DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Base",
+        },
+        "OpenCoder-8B-Base": {
+            DownloadSource.DEFAULT: "infly/OpenCoder-8B-Base",
+            DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Base",
+        },
+        "OpenCoder-1.5B-Instruct": {
+            DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Instruct",
+            DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Instruct",
+        },
+        "OpenCoder-8B-Instruct": {
+            DownloadSource.DEFAULT: "infly/OpenCoder-8B-Instruct",
+            DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Instruct",
+        },
+    },
+    template="opencoder",
+)
+
+
 register_model_group(
     models={
         "Orion-14B-Base": {

@@ -1141,14 +1223,6 @@ register_model_group(
             DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
         },
-        "Phi-3-7B-8k-Instruct": {
-            DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
-            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
-        },
-        "Phi-3-7B-128k-Instruct": {
-            DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
-            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
-        },
         "Phi-3-14B-8k-Instruct": {
             DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
             DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",

@@ -1162,6 +1236,33 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Phi-3-7B-8k-Instruct": {
+            DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
+        },
+        "Phi-3-7B-128k-Instruct": {
+            DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
+        },
+    },
+    template="phi_small",
+)
+
+
+register_model_group(
+    models={
+        "Pixtral-12B-Chat": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="pixtral",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {

@@ -1409,14 +1510,17 @@ register_model_group(
         "Qwen2-0.5B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
         },
         "Qwen2-1.5B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
         },
         "Qwen2-7B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
         },
         "Qwen2-72B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",

@@ -1649,22 +1753,54 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
             DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ",
         },
+        "Qwen2.5-Coder-0.5B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B",
+        },
         "Qwen2.5-Coder-1.5B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
             DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B",
         },
+        "Qwen2.5-Coder-3B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B",
+        },
         "Qwen2.5-Coder-7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
             DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B",
         },
+        "Qwen2.5-Coder-14B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B",
+        },
+        "Qwen2.5-Coder-32B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B",
+        },
+        "Qwen2.5-Coder-0.5B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B-Instruct",
+        },
         "Qwen2.5-Coder-1.5B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
         },
+        "Qwen2.5-Coder-3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B-Instruct",
+        },
         "Qwen2.5-Coder-7B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
         },
+        "Qwen2.5-Coder-14B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B-Instruct",
+        },
+        "Qwen2.5-Coder-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B-Instruct",
+        },
         "Qwen2.5-Math-1.5B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
             DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B",

@@ -1699,10 +1835,12 @@ register_model_group(
         "Qwen2-VL-2B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
         },
         "Qwen2-VL-7B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
+            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
         },
         "Qwen2-VL-72B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",

@@ -1801,10 +1939,12 @@ register_model_group(
         "TeleChat-7B-Chat": {
             DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
             DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
+            DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
         },
         "TeleChat-12B-Chat": {
             DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
             DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
+            DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
         },
         "TeleChat-12B-v2-Chat": {
             DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",

@@ -2023,6 +2163,7 @@ register_model_group(
         "Yi-1.5-6B-Chat": {
             DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
             DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
+            DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
         },
         "Yi-1.5-9B-Chat": {
             DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
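Besides the new model groups and the DownloadSource.OPENMIND hub, the multimodal placeholders are now environment-configurable. A minimal standalone sketch of the pattern used above (the override value is just for illustration):

    import os

    # Same pattern as the new constants: an env var overrides the default token.
    os.environ["IMAGE_PLACEHOLDER"] = "<img>"  # optional override, shown for illustration
    IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
    VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

    print(IMAGE_PLACEHOLDER)  # "<img>" because the override is set above
    print(VIDEO_PLACEHOLDER)  # "<video>" (no override set)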
src/llamafactory/extras/env.py

@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.9.1.dev0"
+VERSION = "0.9.1"


 def print_env() -> None:

@@ -72,4 +72,4 @@ def print_env() -> None:
         except Exception:
             pass

-    print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
+    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
src/llamafactory/extras/logging.py

@@ -20,6 +20,7 @@ import os
 import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 from typing import Optional

 from .constants import RUNNING_LOG

@@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):
     def __init__(self, output_dir: str) -> None:
         super().__init__()
-        formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
-        )
+        self._formatter = logging.Formatter(
+            fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
         self.setLevel(logging.INFO)
-        self.setFormatter(formatter)

         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
         if os.path.exists(self.running_log):

@@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
         if record.name == "httpx":
             return

-        log_entry = self.format(record)
+        log_entry = self._formatter.format(record)
         self.thread_pool.submit(self._write_log, log_entry)

     def close(self) -> None:

@@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
         return super().close()


+class _Logger(logging.Logger):
+    r"""
+    A logger that supports info_rank0 and warning_once.
+    """
+
+    def info_rank0(self, *args, **kwargs) -> None:
+        self.info(*args, **kwargs)
+
+    def warning_rank0(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+    def warning_once(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+
 def _get_default_logging_level() -> "logging._Level":
     r"""
     Returns the default logging level.

@@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
         if env_level_str.upper() in logging._nameToLevel:
             return logging._nameToLevel[env_level_str.upper()]
         else:
-            raise ValueError("Unknown logging level: {}.".format(env_level_str))
+            raise ValueError(f"Unknown logging level: {env_level_str}.")

     return _default_log_level

@@ -84,7 +99,7 @@ def _get_library_name() -> str:
     return __name__.split(".")[0]


-def _get_library_root_logger() -> "logging.Logger":
+def _get_library_root_logger() -> "_Logger":
     return logging.getLogger(_get_library_name())

@@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
     global _default_handler

     with _thread_lock:
-        if _default_handler:
+        if _default_handler:  # already configured
             return

         formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-            datefmt="%m/%d/%Y %H:%M:%S",
+            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         _default_handler = logging.StreamHandler(sys.stdout)
         _default_handler.setFormatter(formatter)

@@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "logging.Logger":
+def get_logger(name: Optional[str] = None) -> "_Logger":
     r"""
     Returns a logger with the specified name. It it not supposed to be accessed externally.
     """

@@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":
     _configure_library_root_logger()
     return logging.getLogger(name)
+
+
+def add_handler(handler: "logging.Handler") -> None:
+    r"""
+    Adds a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    r"""
+    Removes a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().removeHandler(handler)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.info(*args, **kwargs)
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+@lru_cache(None)
+def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+logging.Logger.warning_rank0 = warning_rank0
+logging.Logger.warning_once = warning_once
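The module-level info_rank0 / warning_rank0 / warning_once helpers are patched onto logging.Logger and only emit when LOCAL_RANK is 0, so a multi-GPU run prints each message once instead of once per worker. A standalone sketch of the same idea (it does not import llamafactory):

    import logging
    import os

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")

    def info_rank0(logger: logging.Logger, *args, **kwargs) -> None:
        # Mirror of the patched helper: only the process with LOCAL_RANK == 0 logs.
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            logger.info(*args, **kwargs)

    os.environ["LOCAL_RANK"] = "1"   # pretend to be a non-zero rank
    info_rank0(logger, "suppressed on this worker")

    os.environ["LOCAL_RANK"] = "0"
    info_rank0(logger, "printed exactly once across the job")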
src/llamafactory/extras/misc.py

@@ -20,6 +20,7 @@ import os
 from typing import TYPE_CHECKING, Tuple, Union

 import torch
+import torch.distributed as dist
 import transformers.dynamic_module_utils
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports

@@ -32,7 +33,7 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version

-from .logging import get_logger
+from . import logging


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()

@@ -48,7 +49,7 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class AverageMeter:

@@ -76,12 +77,12 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
-        require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
+        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")

@@ -231,18 +232,43 @@ def torch_gc() -> None:
         torch.cuda.empty_cache()


-def try_download_model_from_ms(model_args: "ModelArguments") -> str:
-    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
+    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
         return model_args.model_name_or_path

-    try:
-        from modelscope import snapshot_download
+    if use_modelscope():
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        from modelscope import snapshot_download  # type: ignore

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
-    except ImportError:
-        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=revision,
+            cache_dir=model_args.cache_dir,
+        )
+
+    if use_openmind():
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind.utils.hub import snapshot_download  # type: ignore
+
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=model_args.model_revision,
+            cache_dir=model_args.cache_dir,
+        )


 def use_modelscope() -> bool:
     return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+
+
+def use_openmind() -> bool:
+    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+
+
+def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
+    r"""
+    calculate effective tokens.
+    """
+    result = effective_token_num * epoch / train_runtime
+    return result / dist.get_world_size() if dist.is_initialized() else result
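try_download_model_from_ms becomes try_download_model_from_other_hub and now dispatches on two environment toggles instead of assuming ModelScope. A hedged sketch of just the dispatch logic, with the actual snapshot_download calls left out:

    import os

    def use_modelscope() -> bool:
        return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]

    def use_openmind() -> bool:
        return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]

    def pick_hub() -> str:
        # Simplified stand-in for try_download_model_from_other_hub:
        # fall back to the Hugging Face hub unless one of the toggles is set.
        if use_modelscope():
            return "modelscope"
        if use_openmind():
            return "openmind"
        return "huggingface"

    os.environ["USE_OPENMIND_HUB"] = "1"
    print(pick_hub())  # -> "openmind"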
src/llamafactory/extras/packages.py

@@ -75,8 +75,13 @@ def is_starlette_available():
 @lru_cache
-def is_transformers_version_greater_than_4_43():
-    return _get_package_version("transformers") >= version.parse("4.43.0")
+def is_transformers_version_greater_than(content: str):
+    return _get_package_version("transformers") >= version.parse(content)
+
+
+@lru_cache
+def is_transformers_version_equal_to_4_46():
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")


 def is_uvicorn_available():
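The pinned helper is_transformers_version_greater_than_4_43 is generalized to take a version string. A rough standalone equivalent using importlib.metadata and packaging (the project's private _get_package_version helper is assumed, so it is re-implemented here):

    import importlib.metadata
    from functools import lru_cache

    from packaging import version

    @lru_cache
    def _get_package_version(name: str) -> "version.Version":
        # Simplified stand-in for the project's private helper.
        try:
            return version.parse(importlib.metadata.version(name))
        except importlib.metadata.PackageNotFoundError:
            return version.parse("0.0.0")

    def is_transformers_version_greater_than(content: str) -> bool:
        return _get_package_version("transformers") >= version.parse(content)

    print(is_transformers_version_greater_than("4.43.0"))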
src/llamafactory/extras/ploting.py

@@ -19,7 +19,7 @@ from typing import Any, Dict, List

 from transformers.trainer import TRAINER_STATE_NAME

-from .logging import get_logger
+from . import logging
 from .packages import is_matplotlib_available

@@ -28,7 +28,7 @@ if is_matplotlib_available():
     import matplotlib.pyplot as plt


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def smooth(scalars: List[float]) -> List[float]:

@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     Plots loss curves and saves the image.
     """
     plt.switch_backend("agg")
-    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)

     for key in keys:

@@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
             metrics.append(data["log_history"][i][key])

         if len(metrics) == 0:
-            logger.warning(f"No metric {key} to plot.")
+            logger.warning_rank0(f"No metric {key} to plot.")
             continue

         plt.figure()
         plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
-        plt.title("training {} of {}".format(key, save_dictionary))
+        plt.title(f"training {key} of {save_dictionary}")
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
src/llamafactory/hparams/data_args.py

@@ -41,8 +41,12 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
+    image_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+    )
     cutoff_len: int = field(
-        default=1024,
+        default=2048,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     train_on_prompt: bool = field(

@@ -111,7 +115,13 @@
     )
     tokenized_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the tokenized datasets."},
+        metadata={
+            "help": (
+                "Path to save or load the tokenized datasets. "
+                "If tokenized_path not exists, it will save the tokenized datasets. "
+                "If tokenized_path exists, it will load the tokenized datasets."
+            )
+        },
     )

     def __post_init__(self):

@@ -123,6 +133,9 @@ class DataArguments:
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)

+        if self.image_dir is None:
+            self.image_dir = self.dataset_dir
+
         if self.dataset is None and self.val_size > 1e-6:
             raise ValueError("Cannot specify `val_size` if `dataset` is None.")
src/llamafactory/hparams/finetuning_args.py

@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    include_effective_tokens_per_second: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to compute effective tokens per second."},
+    )

     def __post_init__(self):
         def split_arg(arg):
src/llamafactory/hparams/model_args.py

@@ -15,10 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union

 import torch
+from transformers.training_args import _convert_str_dict
 from typing_extensions import Self

@@ -57,12 +59,12 @@ class ProcessorArguments:
     """

     image_resolution: int = field(
-        default=512,
-        metadata={"help": "Keeps the height or width of image below this resolution."},
+        default=512 * 512,
+        metadata={"help": "Keeps the number of pixels of image below this resolution."},
     )
     video_resolution: int = field(
-        default=128,
-        metadata={"help": "Keeps the height or width of video below this resolution."},
+        default=128 * 128,
+        metadata={"help": "Keeps the number of pixels of video below this resolution."},
     )
     video_fps: float = field(
         default=2.0,

@@ -125,7 +127,7 @@ class VllmArguments:
     """

     vllm_maxlen: int = field(
-        default=2048,
+        default=4096,
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(

@@ -140,6 +142,10 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
+    vllm_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
+    )


 @dataclass

@@ -267,6 +273,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
+    om_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Modelers Hub."},
+    )
     print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},

@@ -308,6 +318,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")

+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
+
     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
         init_args, lazy_args = {}, {}
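The new vllm_config field accepts either a dict or a JSON string; __post_init__ converts a JSON string that starts with "{" into a dict (the project additionally normalizes value types via transformers' _convert_str_dict). A minimal sketch of the parsing step, with an illustrative config value:

    import json

    # e.g. passed on the CLI as: --vllm_config '{"max_num_seqs": 64, "enforce_eager": true}'
    vllm_config = '{"max_num_seqs": 64, "enforce_eager": true}'

    if isinstance(vllm_config, str) and vllm_config.startswith("{"):
        vllm_config = json.loads(vllm_config)  # string -> dict before handing it to vLLM

    print(vllm_config["max_num_seqs"])  # -> 64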
src/llamafactory/hparams/parser.py

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple

@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments

@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 check_dependencies()

@@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     if unknown_args:
         print(parser.format_help())
-        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
-        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

     return (*parsed_args,)


-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()

@@ -104,7 +103,7 @@ def _verify_model_args(
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

     if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
+        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False

@@ -123,7 +122,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<=0.6.2", "To fix: pip install vllm>=0.4.3,<=0.6.2")
+        require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")

@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

     if data_args.neat_packing and not data_args.packing:
-        logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
+        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
         data_args.packing = True

     _verify_model_args(model_args, data_args, finetuning_args)

@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+        logger.warning_rank0("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")

     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
-        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
-        logger.warning("We recommend enable mixed precision training.")
+        logger.warning_rank0("We recommend enable mixed precision training.")

     if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+        logger.warning_rank0("Using GaLore with mixed precision training may significantly increases GPU memory usage.")

     if (not training_args.do_train) and model_args.quantization_bit is not None:
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
-        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
     if (

@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
         if training_args.resume_from_checkpoint is not None:
-            logger.warning("Cannot resume from checkpoint in current stage.")
+            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
             training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True

@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
-            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
+            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
+            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

     if (
         finetuning_args.stage in ["rm", "ppo"]
         and finetuning_args.finetuning_type == "lora"
         and training_args.resume_from_checkpoint is not None
     ):
-        logger.warning(
+        logger.warning_rank0(
             "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                 training_args.resume_from_checkpoint
src/llamafactory/model/adapter.py

@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model

@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _setup_full_tuning(

@@ -45,7 +45,7 @@ def _setup_full_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Full")
+    logger.info_rank0("Fine-tuning method: Full")
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):

@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Freeze")
+    logger.info_rank0("Fine-tuning method: Freeze")
     if hasattr(model.config, "text_config"):  # composite models
         config = getattr(model.config, "text_config")
     else:

@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
         else:
             param.requires_grad_(False)

-    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+    logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))


 def _setup_lora_tuning(

@@ -145,7 +145,7 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
-            logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
+            logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:

@@ -190,7 +190,7 @@ def _setup_lora_tuning(
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":

@@ -219,7 +219,7 @@ def _setup_lora_tuning(
                     module_names.add(name.split(".")[-1])

             finetuning_args.additional_target = module_names
-            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

         peft_kwargs = {
             "r": finetuning_args.lora_rank,

@@ -236,11 +236,11 @@ def _setup_lora_tuning(
         else:
             if finetuning_args.pissa_init:
                 if finetuning_args.pissa_iter == -1:
-                    logger.info("Using PiSSA initialization.")
+                    logger.info_rank0("Using PiSSA initialization.")
                     peft_kwargs["init_lora_weights"] = "pissa"
                 else:
-                    logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
-                    peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+                    logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                    peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,

@@ -284,11 +284,11 @@ def init_adapter(
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+        logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
     elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
-        logger.info("Upcasting trainable params to float32.")
+        logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":

@@ -300,6 +300,6 @@ def init_adapter(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
     else:
-        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
+        raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")

     return model
src/llamafactory/model/loader.py

@@ -18,8 +18,8 @@ import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

-from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
+from ..extras import logging
+from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass

@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class TokenizerModule(TypedDict):

@@ -50,7 +50,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     Note: including inplace operation of model_args.
     """
     skip_check_imports()
-    model_args.model_name_or_path = try_download_model_from_ms(model_args)
+    model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
     return {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,

@@ -90,17 +90,17 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")

     patch_tokenizer(tokenizer)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
     except Exception as e:
-        logger.warning("Processor was not found: {}.".format(e))
+        logger.debug(f"Processor was not found: {e}.")
         processor = None

     # Avoid load tokenizer, see:

@@ -153,8 +153,9 @@ def load_model(
             load_class = AutoModelForVision2Seq
         else:
             load_class = AutoModelForCausalLM

         if model_args.train_from_scratch:
-            model = load_class.from_config(config)
+            model = load_class.from_config(config, trust_remote_code=True)
         else:
             model = load_class.from_pretrained(**init_kwargs)

@@ -179,7 +180,7 @@ def load_model(
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
-            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
+            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

     if not is_trainable:
         model.requires_grad_(False)

@@ -197,9 +198,9 @@ def load_model(
             trainable_params, all_param, 100 * trainable_params / all_param
         )
     else:
-        param_stats = "all params: {:,}".format(all_param)
+        param_stats = f"all params: {all_param:,}"

-    logger.info(param_stats)
+    logger.info_rank0(param_stats)

     if model_args.print_param_status:
         for name, param in model.named_parameters():
src/llamafactory/model/model_utils/attention.py
View file @
2778a3d0
...
...
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version

-from ...extras.logging import get_logger
+from ...extras import logging


if TYPE_CHECKING:
...
...
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
    from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


def configure_attn_implementation(
...
...
@@ -38,13 +38,15 @@ def configure_attn_implementation(
                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
                if model_args.flash_attn != "fa2":
-                    logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                    model_args.flash_attn = "fa2"
            else:
-                logger.warning("FlashAttention-2 is not installed, use eager attention.")
+                logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
                model_args.flash_attn = "disabled"
        elif model_args.flash_attn == "sdpa":
-            logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
+            logger.warning_rank0("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")

    if model_args.flash_attn == "auto":
        return
...
...
@@ -54,18 +56,18 @@ def configure_attn_implementation(
    elif model_args.flash_attn == "sdpa":
        if not is_torch_sdpa_available():
-            logger.warning("torch>=2.1.1 is required for SDPA attention.")
+            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == "fa2":
        if not is_flash_attn_2_available():
-            logger.warning("FlashAttention-2 is not installed.")
+            logger.warning_rank0("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"
    else:
-        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")

    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        setattr(config, "attn_implementation", requested_attn_implementation)
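Once `configure_attn_implementation` has settled on a backend, it ends up as the `attn_implementation` value handed to the model constructor. An illustrative call using the stock transformers API (placeholder model id; `flash_attention_2` assumes the flash-attn package is installed):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",                # placeholder model id
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    torch_dtype="auto",
)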
...
...
@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
    attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation == "flash_attention_2":
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
-        logger.info("Using torch SDPA for faster training and inference.")
+        logger.info_rank0("Using torch SDPA for faster training and inference.")
    else:
-        logger.info("Using vanilla attention implementation.")
+        logger.info_rank0("Using vanilla attention implementation.")
src/llamafactory/model/model_utils/checkpointing.py
View file @
2778a3d0
...
...
@@ -19,14 +19,14 @@
# limitations under the License.

import inspect
-from functools import partial, wraps
+from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import torch

+from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES
-from ...extras.logging import get_logger


if TYPE_CHECKING:
...
...
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
    from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


def get_unsloth_gradient_checkpointing_func() -> Callable:
...
...
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
    Only applies gradient checkpointing to trainable layers.
    """

-    @wraps(gradient_checkpointing_func)
+    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
        module: "torch.nn.Module" = func.__self__
...
...
@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
        return gradient_checkpointing_func(func, *args, **kwargs)

-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
-        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
-
    return custom_gradient_checkpointing_func
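The removed block is no longer needed because `wraps(..., assigned=WRAPPER_ASSIGNMENTS + ("__self__",))` now copies `__self__` from the wrapped bound method automatically. A minimal demonstration of that behaviour with a toy class (not project code):

from functools import WRAPPER_ASSIGNMENTS, wraps


class Layer:
    def forward(self, x):
        return x * 2


bound_method = Layer().forward  # bound methods expose their instance as `__self__`


@wraps(bound_method, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def wrapper(*args, **kwargs):
    return bound_method(*args, **kwargs)


# `wraps` copied the attribute, so the wrapper still knows its owning instance.
assert wrapper.__self__ is bound_method.__self__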
...
...
@@ -111,7 +108,7 @@ def _gradient_checkpointing_enable(
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
-        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}
...
...
@@ -125,7 +122,7 @@ def _gradient_checkpointing_enable(
    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
-        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
...
...
@@ -144,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
    (3) add the upcasting of the lm_head in fp32
    """
    if model_args.upcast_layernorm:
-        logger.info("Upcasting layernorm weights in float32.")
+        logger.info_rank0("Upcasting layernorm weights in float32.")
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
-            logger.warning("Current model does not support gradient checkpointing.")
+            logger.warning_rank0("Current model does not support gradient checkpointing.")
        else:
            # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
...
...
@@ -161,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
            model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
-            logger.info("Gradient checkpointing enabled.")
+            logger.info_rank0("Gradient checkpointing enabled.")

    if model_args.upcast_lmhead_output:
        output_layer = model.get_output_embeddings()
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            logger.info("Upcasting lm_head outputs in float32.")
+            logger.info_rank0("Upcasting lm_head outputs in float32.")
            output_layer.register_forward_hook(_fp32_forward_post_hook)
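For reference, the call patched in by `prepare_model_for_training` mirrors the stock transformers API; the same effect can be reproduced directly. A usage sketch with a placeholder model id:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model id
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.config.use_cache = False  # the KV cache is incompatible with checkpointed training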
src/llamafactory/model/model_utils/embedding.py
View file @
2778a3d0
...
...
@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled

-from ...extras.logging import get_logger
+from ...extras import logging


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
...
...
@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
    _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
    _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

-    logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
+    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
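`_noisy_mean_initialization` fills the newly added embedding rows with the mean of the existing rows plus a small amount of Gaussian noise. A stand-alone sketch of that idea, with the exact noise scale as an assumption rather than the repository's implementation:

import math

import torch


def noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    # Initialize the last `num_new_tokens` rows from the mean of the old rows plus scaled noise.
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
    noise_weight.normal_(mean=0.0, std=1.0 / math.sqrt(embedding_dim))
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight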