OpenDAS / LLaMA-Factory

Commit c7d1b209 ("Update 0429")
Authored Apr 29, 2025 by chenych
Parent: c8d12c06

Changes: 65. Showing 20 changed files with 760 additions and 336 deletions (+760, -336).
Files changed:

scripts/api_example/test_toolcall.py                  +2    -2
scripts/convert_ckpt/llamafy_baichuan2.py             +1    -1
src/llamafactory/__init__.py                          +0    -15
src/llamafactory/chat/hf_engine.py                    +0    -2
src/llamafactory/cli.py                               +75   -92
src/llamafactory/data/collator.py                     +1    -1
src/llamafactory/data/mm_plugin.py                    +335  -147
src/llamafactory/data/template.py                     +76   -22
src/llamafactory/extras/constants.py                  +141  -3
src/llamafactory/extras/logging.py                    +1    -1
src/llamafactory/extras/misc.py                       +22   -16
src/llamafactory/hparams/finetuning_args.py           +8    -0
src/llamafactory/hparams/model_args.py                +16   -3
src/llamafactory/hparams/parser.py                    +33   -9
src/llamafactory/hparams/training_args.py             +15   -0
src/llamafactory/model/adapter.py                     +7    -3
src/llamafactory/model/loader.py                      +10   -3
src/llamafactory/model/model_utils/attention.py       +3    -3
src/llamafactory/model/model_utils/checkpointing.py   +4    -3
src/llamafactory/model/model_utils/longlora.py        +10   -10
scripts/api_example/test_toolcall.py  (view file @ c7d1b209)

@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:
 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     tools = [
         {
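For reference, os.getenv is just a shorthand for os.environ.get, so this change is purely cosmetic; a minimal sketch of the equivalence:

import os

# os.getenv("API_KEY", "0") and os.environ.get("API_KEY", "0") return the same
# value; the commit simply standardizes on the shorter spelling.
assert os.getenv("API_KEY", "0") == os.environ.get("API_KEY", "0")
print(os.getenv("API_PORT", 8000))  # falls back to the default when the variable is unset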
scripts/convert_ckpt/llamafy_baichuan2.py  (view file @ c7d1b209)

@@ -32,7 +32,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
     baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
-            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
+            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu", weights_only=True)
             baichuan2_state_dict.update(shard_weight)

     llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
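The only change here is weights_only=True, which tells torch.load to deserialize tensors and plain containers without unpickling arbitrary objects. A minimal sketch of the safer call (the shard file name below is hypothetical):

import torch
from collections import OrderedDict

state_dict: dict[str, torch.Tensor] = OrderedDict()
# weights_only=True refuses to execute pickle payloads, so a tampered
# checkpoint file cannot run code while it is being loaded.
shard_weight = torch.load("pytorch_model-00001-of-00002.bin", map_location="cpu", weights_only=True)
state_dict.update(shard_weight)
print(len(state_dict), "tensors loaded")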
src/llamafactory/__init__.py  (view file @ c7d1b209)

@@ -17,23 +17,8 @@ r"""Efficient fine-tuning of large language models.
 Level:
     api, webui > chat, eval, train > data, model > hparams > extras

-Dependency graph:
-    main:
-        transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
-        datasets>=2.16.0,<=3.5.0
-        accelerate>=0.34.0,<=1.6.0
-        peft>=0.14.0,<=0.15.1
-        trl>=0.8.6,<=0.9.6
-    attention:
-        transformers>=4.42.4 (gemma+fa2)
-    longlora:
-        transformers>=4.41.2,<4.48.0
-    packing:
-        transformers>=4.43.0
-
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
-Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
src/llamafactory/chat/hf_engine.py  (view file @ c7d1b209)

@@ -25,7 +25,6 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
-from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response

@@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine):
             inputs=inputs,
             attention_mask=attention_mask,
             generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor(),
         )
         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
src/llamafactory/cli.py  (view file @ c7d1b209)

@@ -16,17 +16,7 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
-
-from . import launcher
-from .api.app import run_api
-from .chat.chat_model import run_chat
-from .eval.evaluator import run_eval
-from .extras import logging
-from .extras.env import VERSION, print_env
-from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-from .train.tuner import export_model, run_exp
-from .webui.interface import run_web_demo, run_web_ui
+from functools import partial

 USAGE = (

@@ -44,7 +34,21 @@ USAGE = (
     + "-" * 70
 )

-WELCOME = (
-    "-" * 58
-    + "\n"
-    + f"| Welcome to LLaMA Factory, version {VERSION}"
+
+def main():
+    from . import launcher
+    from .api.app import run_api
+    from .chat.chat_model import run_chat
+    from .eval.evaluator import run_eval
+    from .extras import logging
+    from .extras.env import VERSION, print_env
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .train.tuner import export_model, run_exp
+    from .webui.interface import run_web_demo, run_web_ui
+
+    logger = logging.get_logger(__name__)
+
+    WELCOME = (
+        "-" * 58
+        + "\n"
+        + f"| Welcome to LLaMA Factory, version {VERSION}"

@@ -54,40 +58,24 @@ WELCOME = (
     + "| \n"
     + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
     + "-" * 58
 )

-logger = logging.get_logger(__name__)
-
-
-@unique
-class Command(str, Enum):
-    API = "api"
-    CHAT = "chat"
-    ENV = "env"
-    EVAL = "eval"
-    EXPORT = "export"
-    TRAIN = "train"
-    WEBDEMO = "webchat"
-    WEBUI = "webui"
-    VER = "version"
-    HELP = "help"
-
-
-def main():
-    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
-    if command == Command.API:
-        run_api()
-    elif command == Command.CHAT:
-        run_chat()
-    elif command == Command.ENV:
-        print_env()
-    elif command == Command.EVAL:
-        run_eval()
-    elif command == Command.EXPORT:
-        export_model()
-    elif command == Command.TRAIN:
-        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-        if force_torchrun or (get_device_count() > 1 and not use_ray()):
+    COMMAND_MAP = {
+        "api": run_api,
+        "chat": run_chat,
+        "env": print_env,
+        "eval": run_eval,
+        "export": export_model,
+        "train": run_exp,
+        "webchat": run_web_demo,
+        "webui": run_web_ui,
+        "version": partial(print, WELCOME),
+        "help": partial(print, USAGE),
+    }
+
+    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))

@@ -123,19 +111,14 @@ def main():
                 check=True,
             )
             sys.exit(process.returncode)
-        else:
-            run_exp()
-    elif command == Command.WEBDEMO:
-        run_web_demo()
-    elif command == Command.WEBUI:
-        run_web_ui()
-    elif command == Command.VER:
-        print(WELCOME)
-    elif command == Command.HELP:
-        print(USAGE)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
         print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
+    from multiprocessing import freeze_support
+
+    freeze_support()
     main()
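The enum-based if/elif chain is replaced by a plain dictionary dispatch, with module imports deferred into main() so the CLI starts faster. A toy sketch of the same pattern (not the actual CLI; the names are illustrative):

import sys
from functools import partial

USAGE = "usage: demo.py <api|chat|version|help>"
WELCOME = "Welcome to the demo CLI"


def run_api():
    print("starting api server")


def run_chat():
    print("starting chat loop")


COMMAND_MAP = {
    "api": run_api,
    "chat": run_chat,
    "version": partial(print, WELCOME),
    "help": partial(print, USAGE),
}


def main():
    # pop the sub-command so later argument parsing never sees it
    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
    if command in COMMAND_MAP:
        COMMAND_MAP[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    main()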
src/llamafactory/data/collator.py  (view file @ c7d1b209)

@@ -176,7 +176,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "input_ids": features["input_ids"],
                 "image_grid_thw": mm_inputs.get("image_grid_thw"),
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
-                "attention_mask": features["attention_mask"],
+                "attention_mask": (features["attention_mask"] >= 1).float(),
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
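A quick illustration of what the (features["attention_mask"] >= 1).float() cast does: any mask holding values greater than one (for example per-sequence indices written by sequence packing) is collapsed back into a binary float mask before the rope index is computed. That reading of the motivation is an assumption; the arithmetic itself is shown below.

import torch

# A packed-style mask with per-sequence indices (1, 2, ...) and padding zeros.
attention_mask = torch.tensor([[1, 1, 2, 2, 0],
                               [1, 1, 1, 0, 0]])
binary_mask = (attention_mask >= 1).float()
print(binary_mask)
# tensor([[1., 1., 1., 1., 0.],
#         [1., 1., 1., 0., 0.]])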
src/llamafactory/data/mm_plugin.py  (view file @ c7d1b209)

This diff is collapsed (+335, -147).
src/llamafactory/data/template.py  (view file @ c7d1b209)

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

 from typing_extensions import override

 from ..extras import logging
-from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin

@@ -61,7 +61,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids

@@ -77,7 +77,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:

@@ -111,12 +111,18 @@ class Template:
         return token_ids

+    def _remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
+        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.

@@ -134,14 +140,18 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -318,6 +328,7 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
+        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []

@@ -331,14 +342,18 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=system_text + content)
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -477,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
     if len(user_slot) > len(user_slot_empty_system):
         default_system = find_diff(user_slot_empty_system, user_slot)

@@ -518,9 +534,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template = TEMPLATES[data_args.template]

-    if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        check_version("transformers>=4.45.0")
-
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

@@ -871,6 +884,18 @@ register_template(
 )

+register_template(
+    name="granite3_vision",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
 register_template(
     name="index",
     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),

@@ -923,6 +948,20 @@ register_template(
 )

+register_template(
+    name="intern_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    default_system=(
+        "你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+    ),
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
+)
+
 register_template(
     name="kimi_vl",
     format_user=StringFormatter(

@@ -1389,6 +1428,21 @@ register_template(
 )

+# copied from qwen template
+register_template(
+    name="qwen3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+)
+
 # copied from chatml template
 register_template(
     name="qwen2_audio",
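The new remove_thought path strips reasoning spans (for example Qwen3's <think>...</think> blocks) from earlier assistant turns before they are re-encoded into the prompt. A standalone sketch of the same regex, assuming the default thought markers:

import re

thought_words = ("<think>", "</think>")  # assumed default markers


def remove_thought(content: str) -> str:
    # Non-greedy match across newlines, then drop the blank lines left behind
    # once the thought block has been removed.
    pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")


message = "<think>\nLet me add 19 and 23...\n</think>\n\nThe sum is 42."
print(remove_thought(message))  # "The sum is 42."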
src/llamafactory/extras/constants.py  (view file @ c7d1b209)

@@ -22,7 +22,7 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

-AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
+AUDIO_PLACEHOLDER = os.getenv("AUDIO_PLACEHOLDER", "<audio>")

 CHECKPOINT_NAMES = {
     SAFE_ADAPTER_WEIGHTS_NAME,

@@ -50,7 +50,7 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
+IMAGE_PLACEHOLDER = os.getenv("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -89,7 +89,7 @@ SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 SWANLAB_CONFIG = "swanlab_public_config.json"

-VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
+VIDEO_PLACEHOLDER = os.getenv("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

@@ -838,11 +838,46 @@ register_model_group(
             DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
         },
+        "Granite-3.2-2B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.2-2b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-2b-instruct",
+        },
+        "Granite-3.2-8B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.2-8b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-8b-instruct",
+        },
+        "Granite-3.3-2B-Base": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-base",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-base",
+        },
+        "Granite-3.3-8B-Base": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-base",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-base",
+        },
+        "Granite-3.3-2B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-instruct",
+        },
+        "Granite-3.3-8B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-instruct",
+        },
     },
     template="granite3",
 )

+register_model_group(
+    models={
+        "Granite-3.2-1B-A400M-Base": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
+        },
+    },
+    template="granite3_vision",
+)
+
 register_model_group(
     models={
         "Hunyuan-7B-Instruct": {

@@ -965,6 +1000,46 @@ register_model_group(
 )

+register_model_group(
+    models={
+        "InternVL2.5-2B-MPO": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-2B-MPO-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-2B-MPO-hf",
+        },
+        "InternVL2.5-8B-MPO": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-8B-MPO-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-8B-MPO-hf",
+        },
+        "InternVL3-1B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-1B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-1B-hf",
+        },
+        "InternVL3-2B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-2B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-2B-hf",
+        },
+        "InternVL3-8B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-8B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-8B-hf",
+        },
+        "InternVL3-14B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-14B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-14B-hf",
+        },
+        "InternVL3-38B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-38B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-38B-hf",
+        },
+        "InternVL3-78B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-78B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-78B-hf",
+        },
+    },
+    template="intern_vl",
+    multimodal=True,
+)
+
 register_model_group(
     models={
         "Jamba-v0.1": {

@@ -2328,6 +2403,69 @@ register_model_group(
 )

+register_model_group(
+    models={
+        "Qwen3-0.6B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
+        },
+        "Qwen3-1.7B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
+        },
+        "Qwen3-4B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
+        },
+        "Qwen3-8B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
+        },
+        "Qwen3-14B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
+        },
+        "Qwen3-30B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
+        },
+        "Qwen3-0.6B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
+        },
+        "Qwen3-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
+        },
+        "Qwen3-4B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
+        },
+        "Qwen3-8B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
+        },
+        "Qwen3-14B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
+        },
+        "Qwen3-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
+        },
+        "Qwen3-30B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
+        },
+        "Qwen3-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
+        },
+    },
+    template="qwen3",
+)
+
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
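New checkpoints are exposed by appending entries to the existing register_model_group calls. A minimal sketch of the same pattern for one extra model (the model name and repository below are hypothetical, not part of this commit; vision models additionally pass multimodal=True):

from llamafactory.extras.constants import DownloadSource, register_model_group

register_model_group(
    models={
        "My-Chat-7B-Instruct": {  # hypothetical display name
            DownloadSource.DEFAULT: "my-org/my-chat-7b-instruct",
            DownloadSource.MODELSCOPE: "my-org/my-chat-7b-instruct",
        },
    },
    template="qwen3",
)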
src/llamafactory/extras/logging.py  (view file @ c7d1b209)

@@ -79,7 +79,7 @@ class _Logger(logging.Logger):
 def _get_default_logging_level() -> "logging._Level":
     r"""Return the default logging level."""
-    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
             return logging._nameToLevel[env_level_str.upper()]
src/llamafactory/extras/misc.py  (view file @ c7d1b209)

@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.5.0")
     check_version("accelerate>=0.34.0,<=1.6.0")
     check_version("peft>=0.14.0,<=0.15.1")

@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
 def get_current_device() -> "torch.device":
     r"""Get the current available device."""
     if is_torch_xpu_available():
-        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
-        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_mps_available():
-        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_cuda_available():
-        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
     else:
         device = "cpu"

@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":
 def get_device_count() -> int:
-    r"""Get the number of available GPU or NPU devices."""
+    r"""Get the number of available devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
         return torch.npu.device_count()
+    elif is_torch_mps_available():
+        return torch.mps.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
     else:

@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":
 def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
-    if is_torch_npu_available():
-        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
-    elif is_torch_xpu_available():
+    if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
+    elif is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), -1
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:

@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
     return torch.float32

-def is_gpu_or_npu_available() -> bool:
-    r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
+def is_accelerator_available() -> bool:
+    r"""Check if the accelerator is available."""
+    return (
+        is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
+    )

 def is_env_enabled(env_var: str, default: str = "0") -> bool:

@@ -229,7 +235,7 @@ def skip_check_imports() -> None:
 def torch_gc() -> None:
-    r"""Collect GPU or NPU memory."""
+    r"""Collect the device memory."""
     gc.collect()
     if is_torch_xpu_available():
         torch.xpu.empty_cache()

@@ -280,7 +286,7 @@ def use_ray() -> bool:
 def find_available_port() -> int:
-    """Find an available port on the local machine."""
+    r"""Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]

@@ -288,8 +294,8 @@ def find_available_port() -> int:
     return port

-def fix_proxy(ipv6_enabled: bool) -> None:
-    """Fix proxy settings for gradio ui."""
+def fix_proxy(ipv6_enabled: bool = False) -> None:
+    r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
         for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
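The helpers above all follow the same dispatch shape: probe each backend in a fixed order, fall back to CPU, and let the process's LOCAL_RANK pick the device index. A simplified, CUDA-only sketch of that pattern (the real version also probes XPU, NPU and MPS):

import os

import torch


def get_current_device() -> "torch.device":
    # Distributed launchers export LOCAL_RANK; single-process runs fall back to index 0.
    if torch.cuda.is_available():
        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
    else:
        device = "cpu"
    return torch.device(device)


print(get_current_device())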
src/llamafactory/hparams/finetuning_args.py  (view file @ c7d1b209)

@@ -411,6 +411,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_muon: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the Muon optimizer."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},

@@ -431,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
src/llamafactory/hparams/model_args.py  (view file @ c7d1b209)

@@ -65,7 +65,13 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    new_special_tokens: Optional[str] = field(
+    add_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
+    add_special_tokens: Optional[str] = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )

@@ -176,8 +182,11 @@ class BaseModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

-        if self.new_special_tokens is not None:  # support multiple special tokens
-            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+        if self.add_tokens is not None:  # support multiple tokens
+            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
+
+        if self.add_special_tokens is not None:  # support multiple special tokens
+            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]

 @dataclass

@@ -222,6 +231,10 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Use pan and scan to process image for gemma3."},
     )
+    crop_to_patches: bool = field(
+        default=False,
+        metadata={"help": "Whether to crop the image to patches for internvl."},
+    )
     use_audio_in_video: bool = field(
         default=False,
         metadata={"help": "Whether or not to use audio in video inputs."},
src/llamafactory/hparams/parser.py  (view file @ c7d1b209)

@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
 import torch
 import transformers
 import yaml
+from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint

@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if args is not None:
         return args

-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]

@@ -91,6 +96,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()

+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+
+        # avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",

@@ -279,12 +292,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

@@ -321,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False  # important for multimodal dataset
+    if finetuning_args.finetuning_type == "lora":
+        # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
+        training_args.label_names = training_args.label_names or ["labels"]
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:

@@ -407,6 +429,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)

@@ -428,9 +451,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()

     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
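The read_args change means a YAML or JSON config file can now be combined with extra "key=value" arguments on the command line, merged via OmegaConf with the CLI values taking precedence. A minimal sketch with hypothetical keys (the dict stands in for yaml.safe_load of the config file):

from omegaconf import OmegaConf

dict_config = {"learning_rate": 5.0e-5, "num_train_epochs": 3}  # stands in for yaml.safe_load(...)
override_config = OmegaConf.from_cli(["num_train_epochs=2"])    # stands in for sys.argv[2:]
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(merged)  # {'learning_rate': 5e-05, 'num_train_epochs': 2}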
src/llamafactory/hparams/training_args.py  (view file @ c7d1b209)

@@ -34,6 +34,10 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
+    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+        default=None,
+        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
+    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},

@@ -55,6 +59,17 @@ class RayArguments:
         self.use_ray = use_ray()
         if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
             self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))

+        if self.ray_storage_filesystem is not None:
+            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
+                raise ValueError(
+                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
+                )
+
+            import pyarrow.fs as fs
+
+            if self.ray_storage_filesystem == "s3":
+                self.ray_storage_filesystem = fs.S3FileSystem()
+            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
+                self.ray_storage_filesystem = fs.GcsFileSystem()

 @dataclass
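A standalone sketch of what the new post-init does with this option, assuming pyarrow is installed: the string is validated and then swapped for a concrete pyarrow filesystem object.

import pyarrow.fs as fs

ray_storage_filesystem = "s3"  # value that would come from the training arguments

if ray_storage_filesystem not in ["s3", "gs", "gcs"]:
    raise ValueError(f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {ray_storage_filesystem}")

# Replace the string with the matching pyarrow filesystem instance.
filesystem = fs.S3FileSystem() if ray_storage_filesystem == "s3" else fs.GcsFileSystem()
print(type(filesystem).__name__)  # S3FileSystem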
src/llamafactory/model/adapter.py  (view file @ c7d1b209)

@@ -23,7 +23,7 @@ from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
-from .model_utils.visual import get_forbidden_modules, patch_target_modules
+from .model_utils.visual import COMPOSITE_MODELS, get_forbidden_modules, patch_target_modules

 if TYPE_CHECKING:

@@ -100,7 +100,7 @@ def _setup_freeze_tuning(
             hidden_modules.add(name.split(".1.")[-1].split(".")[0])

         if re.search(r"\.\d+\.", name) is None:
-            non_hidden_modules.add(name.split(".")[-2])
+            non_hidden_modules.add(name.split(".")[-2])  # remove weight/bias

     trainable_layers = []
     for module_name in finetuning_args.freeze_trainable_modules:

@@ -121,6 +121,10 @@ def _setup_freeze_tuning(
         trainable_layers.append(module_name)

+    model_type = getattr(model.config, "model_type", None)
+    if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
+        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(

@@ -204,7 +208,7 @@ def _setup_lora_tuning(
     if (
         finetuning_args.use_dora
         and getattr(model, "quantization_method", None) is not None
-        and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+        and getattr(model, "quantization_method", None) != QuantizationMethod.BNB
     ):
         raise ValueError("DoRA is not compatible with PTQ-quantized models.")
src/llamafactory/model/loader.py
View file @
c7d1b209
...
@@ -19,7 +19,6 @@ import torch
...
@@ -19,7 +19,6 @@ import torch
from
transformers
import
(
from
transformers
import
(
AutoConfig
,
AutoConfig
,
AutoModelForCausalLM
,
AutoModelForCausalLM
,
AutoModelForImageTextToText
,
AutoModelForSeq2SeqLM
,
AutoModelForSeq2SeqLM
,
AutoModelForTextToWaveform
,
AutoModelForTextToWaveform
,
AutoModelForVision2Seq
,
AutoModelForVision2Seq
,
...
@@ -30,6 +29,7 @@ from trl import AutoModelForCausalLMWithValueHead
...
@@ -30,6 +29,7 @@ from trl import AutoModelForCausalLMWithValueHead
from
..extras
import
logging
from
..extras
import
logging
from
..extras.misc
import
count_parameters
,
skip_check_imports
,
try_download_model_from_other_hub
from
..extras.misc
import
count_parameters
,
skip_check_imports
,
try_download_model_from_other_hub
from
..extras.packages
import
is_transformers_version_greater_than
from
.adapter
import
init_adapter
from
.adapter
import
init_adapter
from
.model_utils.liger_kernel
import
apply_liger_kernel
from
.model_utils.liger_kernel
import
apply_liger_kernel
from
.model_utils.misc
import
register_autoclass
from
.model_utils.misc
import
register_autoclass
...
@@ -39,6 +39,10 @@ from .model_utils.valuehead import load_valuehead_params
...
@@ -39,6 +39,10 @@ from .model_utils.valuehead import load_valuehead_params
from
.patcher
import
patch_config
,
patch_model
,
patch_processor
,
patch_tokenizer
,
patch_valuehead_model
from
.patcher
import
patch_config
,
patch_model
,
patch_processor
,
patch_tokenizer
,
patch_valuehead_model
if
is_transformers_version_greater_than
(
"4.46.0"
):
from
transformers
import
AutoModelForImageTextToText
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
PretrainedConfig
,
PreTrainedModel
,
PreTrainedTokenizer
,
ProcessorMixin
from
transformers
import
PretrainedConfig
,
PreTrainedModel
,
PreTrainedTokenizer
,
ProcessorMixin
...
@@ -97,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
...
@@ -97,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor
=
AutoProcessor
.
from_pretrained
(
model_args
.
model_name_or_path
,
**
init_kwargs
)
processor
=
AutoProcessor
.
from_pretrained
(
model_args
.
model_name_or_path
,
**
init_kwargs
)
patch_processor
(
processor
,
tokenizer
,
model_args
)
patch_processor
(
processor
,
tokenizer
,
model_args
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
debug
(
f
"Failed to load processor:
{
e
}
."
)
logger
.
info_rank0
(
f
"Failed to load processor:
{
e
}
."
)
processor
=
None
processor
=
None
# Avoid load tokenizer, see:
# Avoid load tokenizer, see:
...
@@ -145,7 +149,10 @@ def load_model(
     else:
         if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
             load_class = AutoModelForVision2Seq
-        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+        elif (
+            is_transformers_version_greater_than("4.46.0")
+            and type(config) in AutoModelForImageTextToText._model_mapping.keys()
+        ):  # image-text
             load_class = AutoModelForImageTextToText
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
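Each `Auto*` wrapper keeps a private `_model_mapping` keyed by config class, which is what this branch inspects; the added version check merely skips the lookup when `AutoModelForImageTextToText` was never imported. A condensed, hypothetical version of the dispatch:

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForVision2Seq


def pick_load_class(config, image_text_class=None):
    # image_text_class would be AutoModelForImageTextToText on transformers >= 4.46.0, else None.
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
        return AutoModelForVision2Seq
    if image_text_class is not None and type(config) in image_text_class._model_mapping.keys():  # image-text
        return image_text_class
    if type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
        return AutoModelForSeq2SeqLM
    return AutoModelForCausalLM  # default: plain causal language model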
...
src/llamafactory/model/model_utils/attention.py
View file @ c7d1b209
...
@@ -18,7 +18,6 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from ...extras import logging
 from ...extras.constants import AttentionFunction
-from ...extras.misc import check_version

 if TYPE_CHECKING:
...
@@ -36,8 +35,6 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
-                check_version("transformers>=4.42.4")
-                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != AttentionFunction.FA2:
                     logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = AttentionFunction.FA2
...
@@ -72,6 +69,9 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
         setattr(config, "attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "kimi_vl":
+        setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
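Multimodal configs such as Kimi-VL carry separate `vision_config` and `text_config` sub-configs, so the chosen attention backend must be written to each of them. A small illustrative helper that generalizes the idea (not the project's exact code):

def set_attn_implementation(config, requested: str) -> None:
    """Write the requested backend ("eager", "sdpa", "flash_attention_2") to the config and its sub-configs."""
    for sub_name in ("vision_config", "text_config"):  # nested configs used by multimodal models
        sub_config = getattr(config, sub_name, None)
        if sub_config is not None:
            setattr(sub_config, "_attn_implementation", requested)
    setattr(config, "_attn_implementation", requested)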
...
src/llamafactory/model/model_utils/checkpointing.py
View file @ c7d1b209
...
@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
     ) -> "torch.Tensor":
         saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
         with torch.no_grad():
-            output = forward_function(hidden_states, *args)
+            outputs = forward_function(hidden_states, *args)

         ctx.save_for_backward(saved_hidden_states)
         ctx.forward_function = forward_function
         ctx.args = args
-        return output
+        return outputs

     @staticmethod
     @torch.cuda.amp.custom_bwd
...
@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
         hidden_states.requires_grad_(True)
         with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+            outputs = ctx.forward_function(hidden_states, *ctx.args)
+            output = outputs[0] if isinstance(outputs, tuple) else outputs

         torch.autograd.backward(output, grad_output)
         return (None, hidden_states.grad) + (None,) * len(ctx.args)
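The two hunks above let the offloaded checkpointing function accept layers that return either a bare tensor or a tuple. A self-contained sketch of the underlying technique, CPU-offloaded activation checkpointing via a custom autograd Function, simplified from (and not identical to) the Unsloth-style implementation:

import torch


class OffloadedCheckpoint(torch.autograd.Function):
    """Keep only a CPU copy of the layer input in forward; recompute the layer in backward."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        ctx.save_for_backward(hidden_states.to("cpu", non_blocking=True))
        ctx.forward_function, ctx.args = forward_function, args
        with torch.no_grad():
            return forward_function(hidden_states, *args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        (saved,) = ctx.saved_tensors
        hidden_states = saved.to(grad_outputs[0].device, non_blocking=True).detach().requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)

        # Tolerate layers that return either a tensor or a tuple whose first item is the hidden states.
        output = outputs[0] if isinstance(outputs, tuple) else outputs
        torch.autograd.backward(output, grad_outputs[0])
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


# Example usage: out = OffloadedCheckpoint.apply(decoder_layer.forward, hidden_states)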
...
src/llamafactory/model/model_utils/longlora.py
View file @ c7d1b209
...
@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
 import torch
 import torch.nn as nn
 import transformers
-from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
...
@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than
 if not is_transformers_version_greater_than("4.48.0"):
-    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+    from transformers.models.llama.modeling_llama import (
+        Cache,
+        LlamaAttention,
+        LlamaFlashAttention2,
+        LlamaSdpaAttention,
+        apply_rotary_pos_emb,
+        repeat_kv,
+    )

 if TYPE_CHECKING:
...
@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
     if attention_mask is not None:
         attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than("4.43.0"):
-        from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
     attn_output: torch.Tensor = _flash_attention_forward(
         query_states,
         key_states,
...
@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
         use_top_left_mask=self._flash_attn_uses_top_left_mask,
         is_causal=self.is_causal,
     )
-    else:
-        attn_output: torch.Tensor = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-        )

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
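The `groupsz`/`num_groups` bookkeeping and the `# shift back` comment come from LongLoRA's shifted sparse attention (S²-Attn): the sequence is split into groups, half of the heads are shifted by half a group, and attention runs within each group before being shifted back. A rough sketch of the forward rearrangement, assuming a `(batch, seq, heads, head_dim)` layout and a sequence length divisible by the group size:

import torch


def shift_and_group(states: torch.Tensor, group_size_ratio: float = 0.25) -> torch.Tensor:
    # states: (batch, seq_len, n_heads, head_dim); seq_len must be divisible by the group size.
    bsz, q_len, n_heads, head_dim = states.size()
    groupsz = int(q_len * group_size_ratio)
    num_groups = q_len // groupsz

    shifted = states.clone()
    # Shift the second half of the heads by half a group so neighbouring groups exchange information.
    shifted[:, :, n_heads // 2 :] = shifted[:, :, n_heads // 2 :].roll(-(groupsz // 2), dims=1)
    # Fold the groups into the batch dimension: attention then runs within each group only.
    return shifted.reshape(bsz * num_groups, groupsz, n_heads, head_dim)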
...
@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.41.2,<4.48.0")
+    check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
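Calling `check_version` with `mandatory=True` turns the supported range into a hard requirement before the monkey patches are installed. A rough equivalent of such a helper, assuming `packaging` for the comparison (the project's implementation may differ):

import importlib.metadata

from packaging.requirements import Requirement


def check_version(requirement: str, mandatory: bool = False) -> None:
    # e.g. check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
    req = Requirement(requirement)
    installed = importlib.metadata.version(req.name)
    if not req.specifier.contains(installed, prereleases=True):
        message = f"{req.name}=={installed} does not satisfy '{requirement}'."
        if mandatory:
            raise RuntimeError(message)
        print("Warning: " + message)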
...