OpenDAS / LLaMA-Factory · Commits

Commit 317a82e2, authored Mar 07, 2025 by chenych
Add QWQ-32B
parent 37b0ad9f
Showing 20 changed files (of 255 changes) with 311 additions and 432 deletions (+311 −432).
src/llamafactory/data/tool_utils.py          +13  −15
src/llamafactory/data/utils.py               +0   −78
src/llamafactory/eval/evaluator.py           +1   −1
src/llamafactory/eval/template.py            +1   −1
src/llamafactory/extras/callbacks.py         +0   −231
src/llamafactory/extras/constants.py         +220 −51
src/llamafactory/extras/env.py               +3   −1
src/llamafactory/extras/misc.py              +18  −8
src/llamafactory/extras/packages.py          +4   −5
src/llamafactory/extras/ploting.py           +1   −1
src/llamafactory/hparams/__init__.py         +1   −1
src/llamafactory/hparams/data_args.py        +4   −4
src/llamafactory/hparams/evaluation_args.py  +1   −1
src/llamafactory/hparams/finetuning_args.py  +3   −3
src/llamafactory/hparams/generating_args.py  +1   −1
src/llamafactory/hparams/model_args.py       +17  −9
src/llamafactory/hparams/parser.py           +16  −18
src/llamafactory/hparams/training_args.py    +5   −1
src/llamafactory/launcher.py                 +1   −1
src/llamafactory/model/__init__.py           +1   −1
src/llamafactory/data/tool_utils.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
 from typing_extensions import override

-from .data_utils import SLOTS
-

 class FunctionCall(NamedTuple):
     name: str
 ...
@@ -56,7 +54,7 @@ QWEN_TOOL_PROMPT = (
     "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
     "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
     """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
-    """"arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
+    """"arguments": <args-json-object>}}\n</tool_call>"""
 )
 ...
@@ -76,7 +74,7 @@ class ToolUtils(ABC):
     @staticmethod
     @abstractmethod
-    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+    def function_formatter(functions: List["FunctionCall"]) -> str:
         r"""
         Generates the assistant message including all the tool calls.
         """
 ...
@@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+    def function_formatter(functions: List["FunctionCall"]) -> str:
         function_text = ""
         for name, arguments in functions:
             function_text += f"Action: {name}\nAction Input: {arguments}\n"

-        return [function_text]
+        return function_text

     @override
     @staticmethod
 ...
@@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+    def function_formatter(functions: List["FunctionCall"]) -> str:
         if len(functions) > 1:
             raise ValueError("GLM-4 does not support parallel functions.")

-        return [f"{functions[0].name}\n{functions[0].arguments}"]
+        return f"{functions[0].name}\n{functions[0].arguments}"

     @override
     @staticmethod
 ...
@@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+    def function_formatter(functions: List["FunctionCall"]) -> str:
         if len(functions) > 1:
             raise ValueError("Llama-3 does not support parallel functions.")

-        return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
+        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'

     @override
     @staticmethod
 ...
@@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+    def function_formatter(functions: List["FunctionCall"]) -> str:
         function_texts = []
         for name, arguments in functions:
             function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')

-        return ["[" + ", ".join(function_texts) + "]"]
+        return "[" + ", ".join(function_texts) + "]"

     @override
     @staticmethod
 ...
@@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
-    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+    def function_formatter(functions: List["FunctionCall"]) -> str:
         function_texts = []
         for name, arguments in functions:
             function_texts.append(
                 "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
             )

-        return ["\n".join(function_texts)]
+        return "\n".join(function_texts)

     @override
     @staticmethod
 ...
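The substance of this file's diff: every `function_formatter` now returns a plain `str` instead of a single-element `SLOTS` list, and the Qwen tool prompt no longer bakes a trailing `<|im_end|>` into the template. A minimal standalone sketch of the new contract (the `FunctionCall` below mirrors the NamedTuple above; the demo call is illustrative):

from typing import List, NamedTuple


class FunctionCall(NamedTuple):
    name: str
    arguments: str  # JSON-encoded argument object


def qwen_function_formatter(functions: List[FunctionCall]) -> str:
    # Mirrors QwenToolUtils.function_formatter after this commit: each call is
    # wrapped in <tool_call> tags, and the calls are joined into one plain
    # string instead of being returned as a one-element list.
    function_texts = []
    for name, arguments in functions:
        function_texts.append(
            "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
        )

    return "\n".join(function_texts)


# Illustrative usage:
print(qwen_function_formatter([FunctionCall("get_weather", '{"city": "Berlin"}')]))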
src/llamafactory/data/utils.py  (deleted, 100644 → 0; view file @ 37b0ad9f)

from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

from datasets import concatenate_datasets, interleave_datasets

from ..extras.logging import get_logger


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments


logger = get_logger(__name__)


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, reserved_label_len)
    max_source_len = max_len - min(max_target_len, target_len)
    return max_source_len, max_target_len


def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    if len(all_datasets) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=training_args.seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )
    else:
        raise ValueError("Unknown mixing strategy.")


def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
) -> Dict[str, "Dataset"]:
    if training_args.do_train:
        if data_args.val_size > 1e-6:  # Split the dataset
            if data_args.streaming:
                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
                val_set = dataset.take(int(data_args.val_size))
                train_set = dataset.skip(int(data_args.val_size))
                return {"train_dataset": train_set, "eval_dataset": val_set}
            else:
                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
                dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
                return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            if data_args.streaming:
                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
            return {"train_dataset": dataset}
    else:  # do_eval or do_predict
        return {"eval_dataset": dataset}
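This file is removed outright; its mixing logic is a thin wrapper over the `datasets` library, so the behavior of the deleted `merge_dataset` can be reproduced directly. A standalone sketch under that assumption (toy datasets, illustrative values):

from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

# mix_strategy="concat": plain concatenation.
merged = concatenate_datasets([ds_a, ds_b])

# mix_strategy="interleave_under": probabilistic sampling that stops when the
# first dataset runs out; an "..._over" strategy maps to "all_exhausted".
interleaved = interleave_datasets(
    [ds_a, ds_b], probabilities=[0.5, 0.5], seed=42, stopping_strategy="first_exhausted"
)

print(merged["text"])
print(interleaved["text"])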
src/llamafactory/eval/evaluator.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # This code is inspired by the Dan's test library.
 # https://github.com/hendrycks/test/blob/master/evaluate_flan.py
 ...
src/llamafactory/eval/template.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/extras/callbacks.py  (deleted, 100644 → 0; view file @ 37b0ad9f)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)


class FixValueHeadModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )


class LogCallback(TrainerCallback):
    def __init__(self, output_dir: str) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.
        """
        """ Progress """
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        """ Status """
        self.aborted = False
        self.do_train = False
        """ Web UI """
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = LoggerHandler(output_dir)
            logging.root.addHandler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of the initialization of the `Trainer`.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of an substep during gradient accumulation.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        if not self.do_train:
            self._close_thread_pool()

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.
        """
        if not self.do_train:
            self._close_thread_pool()

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            reward=state.log_history[-1].get("reward", None),
            accuracy=state.log_history[-1].get("rewards/accuracies", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
            total_tokens=state.num_input_tokens_seen,
        )
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
                )
            )

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
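The deleted `LogCallback` had one load-bearing trick worth noting: JSONL progress records are written through a single-worker `ThreadPoolExecutor`, so file I/O never blocks the training loop and writes stay ordered. A minimal standalone sketch of that pattern (file name and payload are illustrative):

import json
import os
from concurrent.futures import ThreadPoolExecutor


def write_log(output_dir: str, logs: dict) -> None:
    # Append one JSON object per line, as the deleted _write_log did.
    with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(logs) + "\n")


os.makedirs("output", exist_ok=True)
pool = ThreadPoolExecutor(max_workers=1)  # one worker keeps writes ordered
pool.submit(write_log, "output", {"current_steps": 1, "total_steps": 10})
pool.shutdown(wait=True)  # flush pending writes before exiting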
src/llamafactory/extras/constants.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -22,6 +22,8 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

+AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
+
 CHECKPOINT_NAMES = {
     SAFE_ADAPTER_WEIGHTS_NAME,
     ADAPTER_WEIGHTS_NAME,
 ...
@@ -58,6 +60,8 @@ METHODS = ["full", "freeze", "lora"]
 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

+MULTIMODAL_SUPPORTED_MODELS = set()
+
 PEFT_METHODS = {"lora"}

 RUNNING_LOG = "running_log.txt"
 ...
@@ -83,14 +87,14 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

+SWANLAB_CONFIG = "swanlab_public_config.json"
+
 VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"

-VISION_MODELS = set()
-

 class DownloadSource(str, Enum):
     DEFAULT = "hf"
 ...
@@ -101,14 +105,16 @@ class DownloadSource(str, Enum):
 def register_model_group(
     models: Dict[str, Dict[DownloadSource, str]],
     template: Optional[str] = None,
-    vision: bool = False,
+    multimodal: bool = False,
 ) -> None:
     for name, path in models.items():
         SUPPORTED_MODELS[name] = path
-        if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision):
+        if template is not None and (any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal):
             DEFAULT_TEMPLATE[name] = template

-        if vision:
-            VISION_MODELS.add(name)
+        if multimodal:
+            MULTIMODAL_SUPPORTED_MODELS.add(name)


 register_model_group(
 ...
@@ -485,14 +491,46 @@ register_model_group(
         DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
         DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
     },
-    "DeepSeek-V3-685B-Base": {
+    "DeepSeek-V3-671B-Base": {
         DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base",
         DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base",
     },
-    "DeepSeek-V3-685B-Chat": {
+    "DeepSeek-V3-671B-Chat": {
         DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
         DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
     },
+    "DeepSeek-R1-1.5B-Distill": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+    },
+    "DeepSeek-R1-7B-Distill": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+    },
+    "DeepSeek-R1-8B-Distill": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    },
+    "DeepSeek-R1-14B-Distill": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
+    },
+    "DeepSeek-R1-32B-Distill": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+    },
+    "DeepSeek-R1-70B-Distill": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+    },
+    "DeepSeek-R1-671B-Chat-Zero": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Zero",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Zero",
+    },
+    "DeepSeek-R1-671B-Chat": {
+        DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
+        DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
+    },
     },
     template="deepseek3",
 )
 ...
@@ -813,20 +851,15 @@ register_model_group(
         DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
         DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
     },
-    },
-    template="intern2",
-)
-
-
-register_model_group(
-    models={
     "InternLM3-8B-Chat": {
         DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct",
         DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct",
     },
     },
-    template="intern3",
+    template="intern2",
 )


 register_model_group(
     models={
         "Jamba-v0.1": {
 ...
@@ -1003,7 +1036,7 @@ register_model_group(
     },
     template="mllama",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1019,7 +1052,7 @@ register_model_group(
     },
     template="llava",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1035,7 +1068,7 @@ register_model_group(
     },
     template="llava_next",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1047,7 +1080,7 @@ register_model_group(
     },
     template="llava_next_mistral",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1059,7 +1092,7 @@ register_model_group(
     },
     template="llava_next_llama3",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1071,7 +1104,7 @@ register_model_group(
     },
     template="llava_next_yi",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1087,7 +1120,7 @@ register_model_group(
     },
     template="llava_next_qwen",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1103,7 +1136,7 @@ register_model_group(
     },
     template="llava_next_video",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1115,7 +1148,7 @@ register_model_group(
     },
     template="llava_next_video_mistral",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1130,7 +1163,7 @@ register_model_group(
     },
     template="llava_next_video_yi",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1174,23 +1207,44 @@ register_model_group(
 register_model_group(
     models={
-        "MiniCPM-o-2_6-Chat": {
+        "MiniCPM-o-2_6": {
             DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
         },
     },
-    template="minicpm_v",
+    template="minicpm_o",
+    multimodal=True,
 )


 register_model_group(
     models={
-        "MiniCPM-V-2_6-Chat": {
+        "MiniCPM-V-2_6": {
             DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
         },
     },
     template="minicpm_v",
+    multimodal=True,
 )


+register_model_group(
+    models={
+        "Ministral-8B-Instruct-2410": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-8B-Instruct-2410",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-8B-Instruct-2410",
+        },
+        "Mistral-Nemo-Base-2407": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Base-2407",
+            DownloadSource.MODELSCOPE: "LLM-Research/Mistral-Nemo-Base-2407",
+        },
+        "Mistral-Nemo-Instruct-2407": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
+        },
+    },
+    template="ministral",
+)
 ...
@@ -1200,48 +1254,60 @@ register_model_group(
         "Mistral-7B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
         },
-        "Mistral-7B-Instruct-v0.1": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
-        },
         "Mistral-7B-v0.2": {
             DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
         },
+        "Mistral-7B-v0.3": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
+            DownloadSource.MODELSCOPE: "LLM-Research/mistral-7b-v0.3",
+        },
+        "Mistral-7B-Instruct-v0.1": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
+        },
         "Mistral-7B-Instruct-v0.2": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
         },
-        "Mistral-7B-v0.3": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
-        },
         "Mistral-7B-Instruct-v0.3": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
             DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
         },
-        "Mistral-Nemo-Instruct-2407": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
-        },
     },
     template="mistral",
 )


+register_model_group(
+    models={
+        "Mistral-Small-24B-Base-2501": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2501",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2501",
+        },
+        "Mistral-Small-24B-Instruct-2501": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2501",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2501",
+        },
+    },
+    template="mistral_small",
+)


 register_model_group(
     models={
         "Mixtral-8x7B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
         },
-        "Mixtral-8x7B-v0.1-Instruct": {
-            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
-        },
         "Mixtral-8x22B-v0.1": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
         },
+        "Mixtral-8x7B-v0.1-Instruct": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
+        },
         "Mixtral-8x22B-v0.1-Instruct": {
             DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
 ...
@@ -1251,6 +1317,21 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Moonlight-16B-A3B": {
+            DownloadSource.DEFAULT: "moonshotai/Moonlight-16B-A3B",
+            DownloadSource.MODELSCOPE: "moonshotai/Moonlight-16B-A3B",
+        },
+        "Moonlight-16B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "moonshotai/Moonlight-16B-A3B-Instruct",
+            DownloadSource.MODELSCOPE: "moonshotai/Moonlight-16B-A3B-Instruct",
+        },
+    },
+    template="moonlight",
+)
+
+
 register_model_group(
     models={
         "OLMo-1B": {
 ...
@@ -1364,7 +1445,7 @@ register_model_group(
     },
     template="paligemma",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1406,9 +1487,33 @@ register_model_group(
         DownloadSource.DEFAULT: "google/paligemma2-28b-pt-896",
         DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-896",
     },
+    "PaliGemma2-3B-mix-224": {
+        DownloadSource.DEFAULT: "google/paligemma2-3b-mix-224",
+        DownloadSource.MODELSCOPE: "mlx-community/paligemma2-3b-mix-224-bf16",
+    },
+    "PaliGemma2-3B-mix-448": {
+        DownloadSource.DEFAULT: "google/paligemma2-3b-mix-448",
+        DownloadSource.MODELSCOPE: "mlx-community/paligemma2-3b-mix-448-bf16",
+    },
+    "PaliGemma2-10B-mix-224": {
+        DownloadSource.DEFAULT: "google/paligemma2-10b-mix-224",
+        DownloadSource.MODELSCOPE: "mlx-community/paligemma2-10b-mix-224-bf16",
+    },
+    "PaliGemma2-10B-mix-448": {
+        DownloadSource.DEFAULT: "google/paligemma2-10b-mix-448",
+        DownloadSource.MODELSCOPE: "mlx-community/paligemma2-10b-mix-448-bf16",
+    },
+    "PaliGemma2-28B-mix-224": {
+        DownloadSource.DEFAULT: "google/paligemma2-28b-mix-224",
+        DownloadSource.MODELSCOPE: "mlx-community/paligemma2-28b-mix-224-bf16",
+    },
+    "PaliGemma2-28B-mix-448": {
+        DownloadSource.DEFAULT: "google/paligemma2-28b-mix-448",
+        DownloadSource.MODELSCOPE: "mlx-community/paligemma2-28b-mix-448-bf16",
+    },
     },
     template="paligemma",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1485,13 +1590,13 @@ register_model_group(
 register_model_group(
     models={
-        "Pixtral-12B-Instruct": {
+        "Pixtral-12B": {
             DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
         }
     },
     template="pixtral",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -1901,6 +2006,14 @@ register_model_group(
         DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
         DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct",
     },
+    "Qwen2.5-7B-Instruct-1M": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-1M",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-1M",
+    },
+    "Qwen2.5-14B-Instruct-1M": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-1M",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-1M",
+    },
     "Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
         DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
         DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
 ...
@@ -2061,6 +2174,10 @@ register_model_group(
         DownloadSource.DEFAULT: "Qwen/QwQ-32B-Preview",
         DownloadSource.MODELSCOPE: "Qwen/QwQ-32B-Preview",
     },
+    "QwQ-32B-Instruct": {
+        DownloadSource.DEFAULT: "Qwen/QwQ-32B",
+        DownloadSource.MODELSCOPE: "Qwen/QwQ-32B",
+    },
     },
     template="qwen",
 )
 ...
@@ -2068,6 +2185,34 @@ register_model_group(
+register_model_group(
+    models={
+        "Qwen2-Audio-7B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B",
+        },
+        "Qwen2-Audio-7B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B-Instruct",
+        },
+    },
+    template="qwen2_audio",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
+        "Qwen2-VL-2B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B",
+        },
+        "Qwen2-VL-7B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B",
+        },
+        "Qwen2-VL-72B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B",
+        },
         "Qwen2-VL-2B-Instruct": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
 ...
@@ -2122,9 +2267,33 @@ register_model_group(
         DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview",
         DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview",
     },
+    "Qwen2.5-VL-3B-Instruct": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-3B-Instruct",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-3B-Instruct",
+    },
+    "Qwen2.5-VL-7B-Instruct": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
+    },
+    "Qwen2.5-VL-72B-Instruct": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
+    },
+    "Qwen2.5-VL-3B-Instruct-AWQ": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-3B-Instruct-AWQ",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-3B-Instruct-AWQ",
+    },
+    "Qwen2.5-VL-7B-Instruct-AWQ": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
+    },
+    "Qwen2.5-VL-72B-Instruct-AWQ": {
+        DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
+        DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
+    },
     },
     template="qwen2_vl",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -2249,7 +2418,7 @@ register_model_group(
     },
     template="video_llava",
-    vision=True,
+    multimodal=True,
 )
 ...
@@ -2476,7 +2645,7 @@ register_model_group(
     },
     template="yi_vl",
-    vision=True,
+    multimodal=True,
 )
 ...
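Across this file the registry flag `vision=True` becomes `multimodal=True`, `VISION_MODELS` gives way to `MULTIMODAL_SUPPORTED_MODELS`, and `-Distill` joins the name suffixes that pick up a default template (which is how the new DeepSeek-R1 distills inherit `deepseek3`). A simplified standalone re-implementation of the registry, not the module itself, to show the flow:

from enum import Enum
from typing import Dict, Optional

SUPPORTED_MODELS: Dict[str, Dict] = {}
DEFAULT_TEMPLATE: Dict[str, str] = {}
MULTIMODAL_SUPPORTED_MODELS = set()


class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"


def register_model_group(models, template: Optional[str] = None, multimodal: bool = False) -> None:
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
        # "-Distill" is newly recognized alongside "-Chat" and "-Instruct".
        if template is not None and (
            any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
        ):
            DEFAULT_TEMPLATE[name] = template

        if multimodal:
            MULTIMODAL_SUPPORTED_MODELS.add(name)


register_model_group(
    models={"QwQ-32B-Instruct": {DownloadSource.DEFAULT: "Qwen/QwQ-32B"}},
    template="qwen",
)
assert DEFAULT_TEMPLATE["QwQ-32B-Instruct"] == "qwen"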
src/llamafactory/extras/env.py  (view file @ 317a82e2)

 ...
@@ -45,6 +45,8 @@ def print_env() -> None:
     if is_torch_cuda_available():
         info["PyTorch version"] += " (GPU)"
         info["GPU type"] = torch.cuda.get_device_name()
+        info["GPU number"] = torch.cuda.device_count()
+        info["GPU memory"] = f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB"

     if is_torch_npu_available():
         info["PyTorch version"] += " (NPU)"
 ...
@@ -59,7 +61,7 @@ def print_env() -> None:
         pass

     try:
-        import bitsandbytes
+        import bitsandbytes  # type: ignore

         info["Bitsandbytes version"] = bitsandbytes.__version__
     except Exception:
 ...
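The two added report fields come straight from the CUDA runtime: `torch.cuda.mem_get_info()` returns `(free, total)` in bytes, and dividing the total by `1024 ** 3` yields the GiB figure printed above. Standalone sketch (needs a CUDA build of PyTorch):

import torch

if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print("GPU number:", torch.cuda.device_count())
    print(f"GPU memory: {total_bytes / (1024 ** 3):.2f}GB")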
src/llamafactory/extras/misc.py  (view file @ 317a82e2)

 ...
@@ -34,6 +34,7 @@ from transformers.utils import (
 from transformers.utils.versions import require_version

 from . import logging
+from .packages import is_transformers_version_greater_than


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 ...
@@ -77,7 +78,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
     r"""
     Optionally checks the package version.
     """
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+    if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return
 ...
@@ -93,11 +94,13 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    check_version("transformers>=4.41.2,<=4.46.1")
-    check_version("datasets>=2.16.0,<=3.1.0")
-    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.2.0")
+    check_version("accelerate>=0.34.0,<=1.2.1")
     check_version("peft>=0.11.1,<=0.12.0")
     check_version("trl>=0.8.6,<=0.9.6")
+    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
+        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
 ...
@@ -223,6 +226,13 @@ def is_gpu_or_npu_available() -> bool:
     return is_torch_npu_available() or is_torch_cuda_available()


+def is_env_enabled(env_var: str, default: str = "0") -> bool:
+    r"""
+    Checks if the environment variable is enabled.
+    """
+    return os.getenv(env_var, default).lower() in ["true", "y", "1"]
+
+
 def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
     r"""
     Casts a torch tensor or a numpy array to a numpy array.
 ...
@@ -241,7 +251,7 @@ def skip_check_imports() -> None:
     r"""
     Avoids flash attention import error in custom model files.
     """
-    if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+    if not is_env_enabled("FORCE_CHECK_IMPORTS"):
         transformers.dynamic_module_utils.check_imports = get_relative_imports
 ...
@@ -287,12 +297,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 def use_modelscope() -> bool:
-    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+    return is_env_enabled("USE_MODELSCOPE_HUB")


 def use_openmind() -> bool:
-    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+    return is_env_enabled("USE_OPENMIND_HUB")


 def use_ray() -> bool:
-    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
+    return is_env_enabled("USE_RAY")
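The new `is_env_enabled` helper collapses the repeated `os.getenv(...).lower() in ["true", "1"]` idiom; note it also accepts `"y"`, which the old inline checks did not. A standalone copy for illustration:

import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Same body as the helper added above.
    return os.getenv(env_var, default).lower() in ["true", "y", "1"]


os.environ["USE_MODELSCOPE_HUB"] = "Y"
assert is_env_enabled("USE_MODELSCOPE_HUB")    # case-insensitive, "y" now counts
assert not is_env_enabled("USE_OPENMIND_HUB")  # unset falls back to default "0"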
src/llamafactory/extras/packages.py  (view file @ 317a82e2)

 ...
@@ -42,6 +42,10 @@ def is_pyav_available():
     return _is_package_available("av")


+def is_librosa_available():
+    return _is_package_available("librosa")
+
+
 def is_fastapi_available():
     return _is_package_available("fastapi")
 ...
@@ -87,11 +91,6 @@ def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)


-@lru_cache
-def is_transformers_version_equal_to_4_46():
-    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
-
-
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
 ...
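`is_librosa_available()` follows the module's `_is_package_available` pattern, which boils down to an importability probe; a standalone equivalent using only the standard library (a sketch, not the project's helper):

import importlib.util


def is_package_available(name: str) -> bool:
    # True iff `import name` would succeed, without actually importing it.
    return importlib.util.find_spec(name) is not None


print("librosa available:", is_package_available("librosa"))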
src/llamafactory/extras/ploting.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/hparams/__init__.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/hparams/data_args.py  (view file @ 317a82e2)

 ...
@@ -41,9 +41,9 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    image_dir: Optional[str] = field(
+    media_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+        metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
     )
     cutoff_len: int = field(
         default=2048,
 ...
@@ -133,8 +133,8 @@ class DataArguments:
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)

-        if self.image_dir is None:
-            self.image_dir = self.dataset_dir
+        if self.media_dir is None:
+            self.media_dir = self.dataset_dir

         if self.dataset is None and self.val_size > 1e-6:
             raise ValueError("Cannot specify `val_size` if `dataset` is None.")
 ...
src/llamafactory/hparams/evaluation_args.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/hparams/finetuning_args.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -238,7 +238,7 @@ class GaloreArguments:
         metadata={"help": "Number of steps to update the GaLore projection."},
     )
     galore_scale: float = field(
-        default=0.25,
+        default=2.0,
         metadata={"help": "GaLore scaling coefficient."},
     )
     galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
 ...
@@ -279,7 +279,7 @@ class ApolloArguments:
         metadata={"help": "Number of steps to update the APOLLO projection."},
     )
     apollo_scale: float = field(
-        default=1.0,
+        default=32.0,
         metadata={"help": "APOLLO scaling coefficient."},
     )
     apollo_proj: Literal["svd", "random"] = field(
 ...
src/llamafactory/hparams/generating_args.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/hparams/model_args.py  (view file @ 317a82e2)

 ...
@@ -58,20 +58,28 @@ class ProcessorArguments:
     Arguments pertaining to the image processor.
     """

-    image_resolution: int = field(
-        default=512 * 512,
-        metadata={"help": "Keeps the number of pixels of image below this resolution."},
+    image_max_pixels: int = field(
+        default=768 * 768,
+        metadata={"help": "The maximum number of pixels of image inputs."},
     )
-    video_resolution: int = field(
-        default=128 * 128,
-        metadata={"help": "Keeps the number of pixels of video below this resolution."},
+    image_min_pixels: int = field(
+        default=32 * 32,
+        metadata={"help": "The minimum number of pixels of image inputs."},
     )
+    video_max_pixels: int = field(
+        default=256 * 256,
+        metadata={"help": "The maximum number of pixels of video inputs."},
+    )
+    video_min_pixels: int = field(
+        default=16 * 16,
+        metadata={"help": "The minimum number of pixels of video inputs."},
+    )
     video_fps: float = field(
         default=2.0,
         metadata={"help": "The frames to sample per second for video inputs."},
     )
     video_maxlen: int = field(
-        default=64,
+        default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
 ...
@@ -87,7 +95,7 @@ class ExportArguments:
         metadata={"help": "Path to the directory to save the exported model."},
     )
     export_size: int = field(
-        default=1,
+        default=5,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
     export_device: Literal["cpu", "auto"] = field(
 ...
@@ -201,7 +209,7 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
 ...
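The processor knobs move from one soft cap per modality (`image_resolution`, `video_resolution`) to explicit max/min pixel bounds. A sketch of how such bounds are typically applied when rescaling follows; this is illustrative math under the new defaults, not code from the repository:

import math


def rescale_to_pixel_bounds(width: int, height: int, max_pixels: int = 768 * 768, min_pixels: int = 32 * 32):
    pixels = width * height
    if pixels > max_pixels:
        scale = math.sqrt(max_pixels / pixels)  # shrink to fit under the cap
    elif pixels < min_pixels:
        scale = math.sqrt(min_pixels / pixels)  # grow to reach the floor
    else:
        return width, height
    return int(width * scale), int(height * scale)


print(rescale_to_pixel_bounds(4032, 3024))  # a 12MP photo lands under 768 * 768 pixels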
src/llamafactory/hparams/parser.py  (view file @ 317a82e2)

 ...
@@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, check_version, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 ...
@@ -55,6 +55,9 @@ _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
 def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
+    r"""
+    Gets arguments from the command line or a config file.
+    """
     if args is not None:
         return args
 ...
@@ -80,13 +83,14 @@ def _parse_args(
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

-    return (*parsed_args,)
+    return tuple(parsed_args)


 def _set_transformers_logging() -> None:
-    transformers.utils.logging.set_verbosity_info()
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
+    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()


 def _verify_model_args(
 ...
@@ -133,7 +137,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == "vllm":
-        check_version("vllm>=0.4.3,<=0.6.5")
+        check_version("vllm>=0.4.3,<=0.7.3")
         check_version("vllm", mandatory=True)

     if finetuning_args.use_galore:
 ...
@@ -159,17 +163,20 @@ def _check_extra_dependencies(

 def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
-    return _parse_args(parser, args)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


 def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
-    return _parse_args(parser, args)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


 def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
-    return _parse_args(parser, args)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


 def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
 ...
@@ -186,9 +193,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     _set_transformers_logging()

     # Check arguments
-    if finetuning_args.stage != "pt" and data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if finetuning_args.stage != "sft":
         if training_args.predict_with_generate:
             raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
 ...
@@ -396,9 +400,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     _set_transformers_logging()

-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if model_args.infer_backend == "vllm":
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
 ...
@@ -429,9 +430,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     _set_transformers_logging()

-    if data_args.template is None:
-        raise ValueError("Please specify which `template` to use.")
-
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 ...
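Two behavioral changes hide in these hunks: transformers logging is now gated on `LLAMAFACTORY_VERBOSITY` being `DEBUG` or `INFO`, and the three `_parse_*` helpers forward `allow_extra_keys` from the `ALLOW_EXTRA_ARGS` env var so unknown config keys can be tolerated instead of raising. A minimal sketch of the latter using `HfArgumentParser.parse_dict` (the dataclass and keys are hypothetical):

import os
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    learning_rate: float = field(default=5e-5)


os.environ.setdefault("ALLOW_EXTRA_ARGS", "1")
allow_extra_keys = os.getenv("ALLOW_EXTRA_ARGS", "0").lower() in ["true", "y", "1"]

parser = HfArgumentParser(DemoArguments)
# With allow_extra_keys=True, "stale_flag" is silently dropped instead of
# triggering a ValueError for unused arguments.
(demo_args,) = parser.parse_dict({"learning_rate": 1e-4, "stale_flag": True}, allow_extra_keys=allow_extra_keys)
print(demo_args.learning_rate)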
src/llamafactory/hparams/training_args.py  (view file @ 317a82e2)

 ...
@@ -16,7 +16,11 @@ class RayArguments:
     ray_run_name: Optional[str] = field(
         default=None,
-        metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
+        metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
     )
+    ray_storage_path: str = field(
+        default="./saves",
+        metadata={"help": "The storage path to save training results to"},
+    )
     ray_num_workers: int = field(
         default=1,
 ...
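The new `ray_storage_path` field makes the save root configurable: results now land at `<ray_storage_path>/<ray_run_name>` rather than under a hard-coded `saves/` prefix. A trivial sketch of the composition (the run name is hypothetical):

import os

ray_storage_path = "./saves"   # the new field's default
ray_run_name = "qwq_32b_lora"  # hypothetical run name

print(os.path.join(ray_storage_path, ray_run_name))  # ./saves/qwq_32b_lora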
src/llamafactory/launcher.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
src/llamafactory/model/__init__.py  (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...