OpenDAS / LLaMA-Factory — Commits

Commit c7d1b209 ("Update 0429"), authored Apr 29, 2025 by chenych
Parent: c8d12c06

Showing 20 changed files with 760 additions and 336 deletions (+760 −336); the commit changes 65 files in total.
Changed files on this page:

    scripts/api_example/test_toolcall.py                   +2    -2
    scripts/convert_ckpt/llamafy_baichuan2.py              +1    -1
    src/llamafactory/__init__.py                           +0    -15
    src/llamafactory/chat/hf_engine.py                     +0    -2
    src/llamafactory/cli.py                                +75   -92
    src/llamafactory/data/collator.py                      +1    -1
    src/llamafactory/data/mm_plugin.py                     +335  -147
    src/llamafactory/data/template.py                      +76   -22
    src/llamafactory/extras/constants.py                   +141  -3
    src/llamafactory/extras/logging.py                     +1    -1
    src/llamafactory/extras/misc.py                        +22   -16
    src/llamafactory/hparams/finetuning_args.py            +8    -0
    src/llamafactory/hparams/model_args.py                 +16   -3
    src/llamafactory/hparams/parser.py                     +33   -9
    src/llamafactory/hparams/training_args.py              +15   -0
    src/llamafactory/model/adapter.py                      +7    -3
    src/llamafactory/model/loader.py                       +10   -3
    src/llamafactory/model/model_utils/attention.py        +3    -3
    src/llamafactory/model/model_utils/checkpointing.py    +4    -3
    src/llamafactory/model/model_utils/longlora.py         +10   -10
scripts/api_example/test_toolcall.py  (+2 −2)

@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:
 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     tools = [
         {
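The only change here is swapping `os.environ.get` for its shorthand `os.getenv`; the two calls behave identically. A minimal, self-contained sketch of the resulting client setup (not part of the diff; it assumes the `openai` package and a locally served LLaMA-Factory API, with the same defaults "0" and 8000 visible above):

import os

from openai import OpenAI

# Same construction as in the updated script: the API key and port fall back to
# the defaults when the environment variables are unset.
client = OpenAI(
    api_key="{}".format(os.getenv("API_KEY", "0")),
    base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
)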
scripts/convert_ckpt/llamafy_baichuan2.py  (+1 −1)

@@ -32,7 +32,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
     baichuan2_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
-            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
+            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu", weights_only=True)
             baichuan2_state_dict.update(shard_weight)

     llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
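The added `weights_only=True` restricts `torch.load` to tensors and plain containers instead of arbitrary pickled objects. A small sketch of loading a single shard the same way (illustrative only; the shard filename below is hypothetical):

import torch

# weights_only=True refuses to unpickle arbitrary Python objects, which is the
# safer choice for checkpoints downloaded from the internet.
shard_weight = torch.load("pytorch_model-00001-of-00002.bin", map_location="cpu", weights_only=True)
print(sum(tensor.numel() for tensor in shard_weight.values()), "parameters in this shard")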
src/llamafactory/__init__.py  (+0 −15)

@@ -17,23 +17,8 @@ r"""Efficient fine-tuning of large language models.

 Level:
   api, webui > chat, eval, train > data, model > hparams > extras

-Dependency graph:
-  main:
-    transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
-    datasets>=2.16.0,<=3.5.0
-    accelerate>=0.34.0,<=1.6.0
-    peft>=0.14.0,<=0.15.1
-    trl>=0.8.6,<=0.9.6
-  attention:
-    transformers>=4.42.4 (gemma+fa2)
-  longlora:
-    transformers>=4.41.2,<4.48.0
-  packing:
-    transformers>=4.43.0
-
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
 Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
src/llamafactory/chat/hf_engine.py  (+0 −2)

@@ -25,7 +25,6 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response

@@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine):
             inputs=inputs,
             attention_mask=attention_mask,
             generation_config=GenerationConfig(**generating_args),
             logits_processor=get_logits_processor(),
         )

         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
src/llamafactory/cli.py  (+75 −92)

@@ -16,17 +16,7 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
-
-from . import launcher
-from .api.app import run_api
-from .chat.chat_model import run_chat
-from .eval.evaluator import run_eval
-from .extras import logging
-from .extras.env import VERSION, print_env
-from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-from .train.tuner import export_model, run_exp
-from .webui.interface import run_web_demo, run_web_ui
+from functools import partial

 USAGE = (

@@ -44,98 +34,91 @@ USAGE = (
     + "-" * 70
 )

-WELCOME = (
-    "-" * 58
-    + "\n"
-    + f"| Welcome to LLaMA Factory, version {VERSION}"
-    + " " * (21 - len(VERSION))
-    + "|\n|"
-    + " " * 56
-    + "|\n"
-    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
-    + "-" * 58
-)
-
-logger = logging.get_logger(__name__)
-
-
-@unique
-class Command(str, Enum):
-    API = "api"
-    CHAT = "chat"
-    ENV = "env"
-    EVAL = "eval"
-    EXPORT = "export"
-    TRAIN = "train"
-    WEBDEMO = "webchat"
-    WEBUI = "webui"
-    VER = "version"
-    HELP = "help"
-
-
 def main():
+    from . import launcher
+    from .api.app import run_api
+    from .chat.chat_model import run_chat
+    from .eval.evaluator import run_eval
+    from .extras import logging
+    from .extras.env import VERSION, print_env
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .train.tuner import export_model, run_exp
+    from .webui.interface import run_web_demo, run_web_ui
+
+    logger = logging.get_logger(__name__)
+
+    WELCOME = (
+        "-" * 58
+        + "\n"
+        + f"| Welcome to LLaMA Factory, version {VERSION}"
+        + " " * (21 - len(VERSION))
+        + "|\n|"
+        + " " * 56
+        + "|\n"
+        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+        + "-" * 58
+    )
+
+    COMMAND_MAP = {
+        "api": run_api,
+        "chat": run_chat,
+        "env": print_env,
+        "eval": run_eval,
+        "export": export_model,
+        "train": run_exp,
+        "webchat": run_web_demo,
+        "webui": run_web_ui,
+        "version": partial(print, WELCOME),
+        "help": partial(print, USAGE),
+    }
+
-    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
-    if command == Command.API:
-        run_api()
-    elif command == Command.CHAT:
-        run_chat()
-    elif command == Command.ENV:
-        print_env()
-    elif command == Command.EVAL:
-        run_eval()
-    elif command == Command.EXPORT:
-        export_model()
-    elif command == Command.TRAIN:
-        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-        if force_torchrun or (get_device_count() > 1 and not use_ray()):
-            nnodes = os.getenv("NNODES", "1")
-            node_rank = os.getenv("NODE_RANK", "0")
-            nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
-            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-            master_port = os.getenv("MASTER_PORT", str(find_available_port()))
-            logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
-            if int(nnodes) > 1:
-                print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
-
-            env = deepcopy(os.environ)
-            if is_env_enabled("OPTIM_TORCH", "1"):
-                # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
-                env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-                env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-            # NOTE: DO NOT USE shell=True to avoid security risk
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
-                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                )
-                .format(
-                    nnodes=nnodes,
-                    node_rank=node_rank,
-                    nproc_per_node=nproc_per_node,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-            sys.exit(process.returncode)
-        else:
-            run_exp()
-    elif command == Command.WEBDEMO:
-        run_web_demo()
-    elif command == Command.WEBUI:
-        run_web_ui()
-    elif command == Command.VER:
-        print(WELCOME)
-    elif command == Command.HELP:
-        print(USAGE)
+    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
+        nnodes = os.getenv("NNODES", "1")
+        node_rank = os.getenv("NODE_RANK", "0")
+        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
+        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
+        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+        if int(nnodes) > 1:
+            print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        # NOTE: DO NOT USE shell=True to avoid security risk
+        process = subprocess.run(
+            (
+                "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+            )
+            .format(
+                nnodes=nnodes,
+                node_rank=node_rank,
+                nproc_per_node=nproc_per_node,
+                master_addr=master_addr,
+                master_port=master_port,
+                file_name=launcher.__file__,
+                args=" ".join(sys.argv[1:]),
+            )
+            .split(),
+            env=env,
+            check=True,
+        )
+        sys.exit(process.returncode)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
         print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
     from multiprocessing import freeze_support

     freeze_support()
     main()
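The refactor above replaces the `Command` enum and its `if`/`elif` chain with a dictionary mapping sub-command strings to callables, plus lazy imports inside `main()`. A standalone sketch of that dispatch pattern (simplified names, not the actual module):

import sys
from functools import partial

USAGE = "usage: demo api | train | help"

def run_api() -> None:
    print("starting the api server ...")

def run_exp() -> None:
    print("starting a training run ...")

# Each value is a zero-argument callable; partial() turns print(USAGE) into one too.
COMMAND_MAP = {"api": run_api, "train": run_exp, "help": partial(print, USAGE)}

def main() -> None:
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command in COMMAND_MAP:
        COMMAND_MAP[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")

if __name__ == "__main__":
    main()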
src/llamafactory/data/collator.py  (+1 −1)

@@ -176,7 +176,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "input_ids": features["input_ids"],
                 "image_grid_thw": mm_inputs.get("image_grid_thw"),
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
-                "attention_mask": features["attention_mask"],
+                "attention_mask": (features["attention_mask"] >= 1).float(),
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
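One plausible reading of this one-line change (not stated in the diff itself): with sequence packing the collator can store segment indices 1, 2, 3, ... in `attention_mask`, and `(attention_mask >= 1).float()` collapses those indices back to a plain 0/1 mask before it is handed to the rope-index helper. Illustrative example:

import torch

packed_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])  # two packed segments plus padding
binary_mask = (packed_mask >= 1).float()
print(binary_mask)  # tensor([[1., 1., 1., 1., 1., 0.]])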
src/llamafactory/data/mm_plugin.py  (+335 −147)

(This diff is collapsed on the page and not shown.)
src/llamafactory/data/template.py  (+76 −22)

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import re
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

 from typing_extensions import override

 from ..extras import logging
 from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin

@@ -61,7 +61,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids

@@ -77,7 +77,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:

@@ -111,12 +111,18 @@ class Template:
         return token_ids

+    def _remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(
+            f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL
+        )
+        return re.sub(pattern, "", content).lstrip("\n")
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
+        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.

@@ -134,14 +140,18 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -318,6 +328,7 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
+        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []

@@ -331,14 +342,18 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]

-            if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+            content = message["content"]
+            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
+                content = self._remove_thought(content)
+
+            if message["role"] == Role.USER:
+                elements += self.format_user.apply(content=system_text + content)
+            elif message["role"] == Role.ASSISTANT:
+                elements += self.format_assistant.apply(content=content)
+            elif message["role"] == Role.OBSERVATION:
+                elements += self.format_observation.apply(content=content)
+            elif message["role"] == Role.FUNCTION:
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))

@@ -477,6 +492,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
     if len(user_slot) > len(user_slot_empty_system):
         default_system = find_diff(user_slot_empty_system, user_slot)

@@ -518,9 +534,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     template = TEMPLATES[data_args.template]

-    if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        check_version("transformers>=4.45.0")
-
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

@@ -871,6 +884,18 @@ register_template(
 )


+register_template(
+    name="granite3_vision",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
 register_template(
     name="index",
     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),

@@ -923,6 +948,20 @@ register_template(
 )


+register_template(
+    name="intern_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    default_system=(
+        "你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+    ),
+    stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
+)
+
+
 register_template(
     name="kimi_vl",
     format_user=StringFormatter(

@@ -1389,6 +1428,21 @@ register_template(
 )


+# copied from qwen template
+register_template(
+    name="qwen3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+)
+
+
 # copied from chatml template
 register_template(
     name="qwen2_audio",
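The new `_remove_thought` helper strips a reasoning span delimited by the template's `thought_words` before earlier assistant turns are re-encoded. A self-contained sketch of the same regex, assuming the common "<think>"/"</think>" delimiters (the real method reads them from the template object):

import re

def remove_thought(content: str, thought_words: tuple[str, str] = ("<think>", "</think>")) -> str:
    # DOTALL lets the span cover newlines; lstrip drops the blank lines left behind.
    pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")

print(remove_thought("<think>\nreason step by step...\n</think>\n\nThe answer is 42."))
# -> "The answer is 42."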
src/llamafactory/extras/constants.py  (+141 −3)

@@ -22,7 +22,7 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

-AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
+AUDIO_PLACEHOLDER = os.getenv("AUDIO_PLACEHOLDER", "<audio>")

 CHECKPOINT_NAMES = {
     SAFE_ADAPTER_WEIGHTS_NAME,

@@ -50,7 +50,7 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
+IMAGE_PLACEHOLDER = os.getenv("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -89,7 +89,7 @@ SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
 SWANLAB_CONFIG = "swanlab_public_config.json"

-VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
+VIDEO_PLACEHOLDER = os.getenv("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

@@ -838,11 +838,46 @@ register_model_group(
             DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
         },
+        "Granite-3.2-2B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.2-2b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-2b-instruct",
+        },
+        "Granite-3.2-8B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.2-8b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-8b-instruct",
+        },
+        "Granite-3.3-2B-Base": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-base",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-base",
+        },
+        "Granite-3.3-8B-Base": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-base",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-base",
+        },
+        "Granite-3.3-2B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-instruct",
+        },
+        "Granite-3.3-8B-Instruct": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-instruct",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-instruct",
+        },
     },
     template="granite3",
 )


+register_model_group(
+    models={
+        "Granite-3.2-1B-A400M-Base": {
+            DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
+        },
+    },
+    template="granite3_vision",
+)
+
+
 register_model_group(
     models={
         "Hunyuan-7B-Instruct": {

@@ -965,6 +1000,46 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "InternVL2.5-2B-MPO": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-2B-MPO-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-2B-MPO-hf",
+        },
+        "InternVL2.5-8B-MPO": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL2_5-8B-MPO-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL2_5-8B-MPO-hf",
+        },
+        "InternVL3-1B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-1B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-1B-hf",
+        },
+        "InternVL3-2B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-2B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-2B-hf",
+        },
+        "InternVL3-8B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-8B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-8B-hf",
+        },
+        "InternVL3-14B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-14B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-14B-hf",
+        },
+        "InternVL3-38B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-38B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-38B-hf",
+        },
+        "InternVL3-78B-hf": {
+            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-78B-hf",
+            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-78B-hf",
+        },
+    },
+    template="intern_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Jamba-v0.1": {

@@ -2328,6 +2403,69 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen3-0.6B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-Base",
+        },
+        "Qwen3-1.7B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-Base",
+        },
+        "Qwen3-4B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Base",
+        },
+        "Qwen3-8B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-Base",
+        },
+        "Qwen3-14B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-Base",
+        },
+        "Qwen3-30B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
+        },
+        "Qwen3-0.6B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
+        },
+        "Qwen3-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
+        },
+        "Qwen3-4B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
+        },
+        "Qwen3-8B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
+        },
+        "Qwen3-14B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
+        },
+        "Qwen3-32B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
+        },
+        "Qwen3-30B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
+        },
+        "Qwen3-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
+        },
+    },
+    template="qwen3",
+)
+
+
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
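As before, the placeholder constants fall back to their defaults but can be overridden through the environment; the change is only the shorter `os.getenv` spelling. Tiny illustration (the override value below is hypothetical):

import os

os.environ["IMAGE_PLACEHOLDER"] = "<img>"  # hypothetical override
IMAGE_PLACEHOLDER = os.getenv("IMAGE_PLACEHOLDER", "<image>")
print(IMAGE_PLACEHOLDER)  # "<img>"; without the override it would be "<image>"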
src/llamafactory/extras/logging.py  (+1 −1)

@@ -79,7 +79,7 @@ class _Logger(logging.Logger):

 def _get_default_logging_level() -> "logging._Level":
     r"""Return the default logging level."""
-    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
             return logging._nameToLevel[env_level_str.upper()]
src/llamafactory/extras/misc.py  (+22 −16)

@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.5.0")
     check_version("accelerate>=0.34.0,<=1.6.0")
     check_version("peft>=0.14.0,<=0.15.1")

@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
 def get_current_device() -> "torch.device":
     r"""Get the current available device."""
     if is_torch_xpu_available():
-        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
-        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_mps_available():
-        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_cuda_available():
-        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
     else:
         device = "cpu"

@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":
 def get_device_count() -> int:
-    r"""Get the number of available GPU or NPU devices."""
+    r"""Get the number of available devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
         return torch.npu.device_count()
+    elif is_torch_mps_available():
+        return torch.mps.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
     else:

@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":
 def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
-    if is_torch_npu_available():
-        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
-    elif is_torch_xpu_available():
+    if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
+    elif is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), -1
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:

@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
     return torch.float32


-def is_gpu_or_npu_available() -> bool:
-    r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
+def is_accelerator_available() -> bool:
+    r"""Check if the accelerator is available."""
+    return (
+        is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
+    )


 def is_env_enabled(env_var: str, default: str = "0") -> bool:

@@ -229,7 +235,7 @@ def skip_check_imports() -> None:
 def torch_gc() -> None:
-    r"""Collect GPU or NPU memory."""
+    r"""Collect the device memory."""
     gc.collect()
     if is_torch_xpu_available():
         torch.xpu.empty_cache()

@@ -280,7 +286,7 @@ def use_ray() -> bool:
 def find_available_port() -> int:
-    """Find an available port on the local machine."""
+    r"""Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]

@@ -288,8 +294,8 @@ def find_available_port() -> int:
     return port


-def fix_proxy(ipv6_enabled: bool) -> None:
-    """Fix proxy settings for gradio ui."""
+def fix_proxy(ipv6_enabled: bool = False) -> None:
+    r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
         for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
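Several branches above and in cli.py gate behaviour on flags such as FORCE_TORCHRUN and OPTIM_TORCH through `is_env_enabled`. The helper itself is not part of this diff; a plausible minimal implementation with the same signature looks like this:

import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Treat "1", "true" and "y" (case-insensitive) as enabled.
    return os.getenv(env_var, default).lower() in ("true", "y", "1")

print(is_env_enabled("FORCE_TORCHRUN"))  # False unless e.g. FORCE_TORCHRUN=1 is exported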
src/llamafactory/hparams/finetuning_args.py  (+8 −0)

@@ -411,6 +411,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_muon: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the Muon optimizer."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},

@@ -431,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
src/llamafactory/hparams/model_args.py  (+16 −3)

@@ -65,7 +65,13 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    new_special_tokens: Optional[str] = field(
+    add_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
+    add_special_tokens: Optional[str] = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )

@@ -176,8 +182,11 @@ class BaseModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

-        if self.new_special_tokens is not None:  # support multiple special tokens
-            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+        if self.add_tokens is not None:  # support multiple tokens
+            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
+
+        if self.add_special_tokens is not None:  # support multiple special tokens
+            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]


 @dataclass

@@ -222,6 +231,10 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Use pan and scan to process image for gemma3."},
     )
+    crop_to_patches: bool = field(
+        default=False,
+        metadata={"help": "Whether to crop the image to patches for internvl."},
+    )
     use_audio_in_video: bool = field(
         default=False,
         metadata={"help": "Whether or not to use audio in video inputs."},
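The new `add_tokens` / `add_special_tokens` options accept a comma-separated string and are split into lists in `__post_init__`, mirroring how `adapter_name_or_path` is handled. Standalone illustration of the splitting (the token values are made up):

add_tokens = "<tool_call>, </tool_call>"  # hypothetical command-line value
tokens = [token.strip() for token in add_tokens.split(",")]
print(tokens)  # ['<tool_call>', '</tool_call>']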
src/llamafactory/hparams/parser.py  (+33 −9)

@@ -24,6 +24,7 @@ from typing import Any, Optional, Union
 import torch
 import transformers
 import yaml
+from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint

@@ -59,10 +60,14 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
     if args is not None:
         return args

-    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
-    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
+    elif sys.argv[1].endswith(".json"):
+        override_config = OmegaConf.from_cli(sys.argv[2:])
+        dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+        return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]

@@ -91,6 +96,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()


+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+
+        # avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",

@@ -279,12 +292,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

@@ -321,12 +335,20 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     training_args.remove_unused_columns = False  # important for multimodal dataset
     if finetuning_args.finetuning_type == "lora":
         # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
         training_args.label_names = training_args.label_names or ["labels"]

     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:

@@ -407,6 +429,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("vLLM only accepts a single adapter. Merge them first.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)

@@ -428,9 +451,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()

     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
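With the OmegaConf change, a YAML or JSON config file can now be combined with dotlist overrides passed after the file path. A minimal sketch of the same merge (the file name and key below are made up):

from pathlib import Path

import yaml
from omegaconf import OmegaConf

dict_config = yaml.safe_load(Path("train_config.yaml").read_text())   # base config file
override_config = OmegaConf.from_cli(["learning_rate=1e-5"])          # trailing CLI arguments
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(merged)  # the base config with learning_rate overridden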
src/llamafactory/hparams/training_args.py  (+15 −0)

@@ -34,6 +34,10 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
+    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+        default=None,
+        metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
+    )
     ray_num_workers: int = field(
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},

@@ -55,6 +59,17 @@ class RayArguments:
         self.use_ray = use_ray()
         if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
             self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))

+        if self.ray_storage_filesystem is not None:
+            if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
+                raise ValueError(
+                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
+                )
+
+            import pyarrow.fs as fs
+
+            if self.ray_storage_filesystem == "s3":
+                self.ray_storage_filesystem = fs.S3FileSystem()
+            elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
+                self.ray_storage_filesystem = fs.GcsFileSystem()


 @dataclass
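After validation, the string value of `ray_storage_filesystem` is replaced in place by an actual pyarrow filesystem object. Sketch of that mapping on its own (requires the pyarrow package; no credentials are configured here):

import pyarrow.fs as fs

ray_storage_filesystem = "s3"  # one of "s3", "gs", "gcs"
if ray_storage_filesystem == "s3":
    filesystem = fs.S3FileSystem()
elif ray_storage_filesystem in ("gs", "gcs"):
    filesystem = fs.GcsFileSystem()
print(type(filesystem).__name__)  # S3FileSystem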
src/llamafactory/model/adapter.py  (+7 −3)

@@ -23,7 +23,7 @@ from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
-from .model_utils.visual import get_forbidden_modules, patch_target_modules
+from .model_utils.visual import COMPOSITE_MODELS, get_forbidden_modules, patch_target_modules


 if TYPE_CHECKING:

@@ -100,7 +100,7 @@ def _setup_freeze_tuning(
             hidden_modules.add(name.split(".1.")[-1].split(".")[0])
         if re.search(r"\.\d+\.", name) is None:
-            non_hidden_modules.add(name.split(".")[-2])
+            non_hidden_modules.add(name.split(".")[-2])  # remove weight/bias

     trainable_layers = []
     for module_name in finetuning_args.freeze_trainable_modules:

@@ -121,6 +121,10 @@ def _setup_freeze_tuning(
             trainable_layers.append(module_name)

+    model_type = getattr(model.config, "model_type", None)
+    if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
+        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(

@@ -204,7 +208,7 @@ def _setup_lora_tuning(
     if (
         finetuning_args.use_dora
         and getattr(model, "quantization_method", None) is not None
-        and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+        and getattr(model, "quantization_method", None) != QuantizationMethod.BNB
     ):
         raise ValueError("DoRA is not compatible with PTQ-quantized models.")
src/llamafactory/model/loader.py  (+10 −3)

@@ -19,7 +19,6 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,

@@ -30,6 +29,7 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
+from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass

@@ -39,6 +39,10 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model


+if is_transformers_version_greater_than("4.46.0"):
+    from transformers import AutoModelForImageTextToText
+
+
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

@@ -97,7 +101,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Failed to load processor: {e}.")
+        logger.info_rank0(f"Failed to load processor: {e}.")
         processor = None

     # Avoid load tokenizer, see:

@@ -145,7 +149,10 @@ def load_model(
         else:
             if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                 load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+            elif (
+                is_transformers_version_greater_than("4.46.0")
+                and type(config) in AutoModelForImageTextToText._model_mapping.keys()
+            ):  # image-text
                 load_class = AutoModelForImageTextToText
             elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                 load_class = AutoModelForSeq2SeqLM
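`AutoModelForImageTextToText` only exists in newer transformers releases, so both the import and its use are now gated on the installed version. A standalone sketch of the same pattern; the helper below is a simplified stand-in for `is_transformers_version_greater_than` from `..extras.packages`:

import transformers
from packaging import version

def is_transformers_version_greater_than(target: str) -> bool:
    # Simplified stand-in; the real helper lives in llamafactory.extras.packages.
    return version.parse(transformers.__version__) >= version.parse(target)

if is_transformers_version_greater_than("4.46.0"):
    from transformers import AutoModelForImageTextToText  # noqa: F401  # only imported when available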
src/llamafactory/model/model_utils/attention.py  (+3 −3)

@@ -18,7 +18,6 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_availabl
 from ...extras import logging
 from ...extras.constants import AttentionFunction
-from ...extras.misc import check_version


 if TYPE_CHECKING:

@@ -36,8 +35,6 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
-                check_version("transformers>=4.42.4")
-                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != AttentionFunction.FA2:
                     logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = AttentionFunction.FA2

@@ -72,6 +69,9 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
         setattr(config, "attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "kimi_vl":
+        setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
src/llamafactory/model/model_utils/checkpointing.py  (+4 −3)

@@ -52,12 +52,12 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
         ) -> "torch.Tensor":
             saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
             with torch.no_grad():
-                output = forward_function(hidden_states, *args)
+                outputs = forward_function(hidden_states, *args)

             ctx.save_for_backward(saved_hidden_states)
             ctx.forward_function = forward_function
             ctx.args = args
-            return output
+            return outputs

         @staticmethod
         @torch.cuda.amp.custom_bwd

@@ -66,7 +66,8 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
             hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
             hidden_states.requires_grad_(True)
             with torch.enable_grad():
-                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+                outputs = ctx.forward_function(hidden_states, *ctx.args)
+                output = outputs[0] if isinstance(outputs, tuple) else outputs

             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
src/llamafactory/model/model_utils/longlora.py  (+10 −10)

@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Optional
 import torch
 import torch.nn as nn
 import transformers
-from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN

@@ -32,7 +31,15 @@ from ...extras.packages import is_transformers_version_greater_than
 if not is_transformers_version_greater_than("4.48.0"):
-    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+    from transformers.models.llama.modeling_llama import (
+        Cache,
+        LlamaAttention,
+        LlamaFlashAttention2,
+        LlamaSdpaAttention,
+        apply_rotary_pos_emb,
+        repeat_kv,
+    )


 if TYPE_CHECKING:

@@ -206,9 +213,6 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than("4.43.0"):
-        from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
     attn_output: torch.Tensor = _flash_attention_forward(
         query_states,
         key_states,

@@ -220,10 +224,6 @@ def llama_flash_attention_2_forward(
         use_top_left_mask=self._flash_attn_uses_top_left_mask,
         is_causal=self.is_causal,
     )
-    else:
-        attn_output: torch.Tensor = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
-        )

     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
         attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)

@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.41.2,<4.48.0")
+    check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward