OpenDAS / LLaMA-Factory — Commits

Commit 84987715, authored Apr 07, 2025 by chenych
parent 317a82e2

    update to v0.9.2

Changes: 58 files. Showing 20 changed files with 353 additions and 286 deletions (+353 −286); the diff is paginated across 3 pages.
requirements.txt                                   +1   -1
scripts/vllm_infer.py                              +9   -5
setup.py                                           +1   -1
src/llamafactory/api/app.py                        +3   -2
src/llamafactory/chat/base_engine.py               +2   -0
src/llamafactory/chat/chat_model.py                +3   -3
src/llamafactory/chat/hf_engine.py                 +2   -1
src/llamafactory/chat/vllm_engine.py               +3   -2
src/llamafactory/cli.py                            +11  -5
src/llamafactory/data/data_utils.py                +64  -11
src/llamafactory/data/formatter.py                 +1   -1
src/llamafactory/data/loader.py                    +10  -59
src/llamafactory/extras/constants.py               +19  -0
src/llamafactory/extras/env.py                     +10  -1
src/llamafactory/hparams/finetuning_args.py        +12  -13
src/llamafactory/hparams/model_args.py             +172 -153
src/llamafactory/hparams/parser.py                 +5   -4
src/llamafactory/model/model_utils/attention.py    +11  -10
src/llamafactory/model/model_utils/rope.py         +13  -13
src/llamafactory/model/model_utils/visual.py       +1   -1
requirements.txt
@@ -5,7 +5,7 @@ accelerate>=0.34.0,<=1.2.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
 tokenizers>=0.19.0,<=0.21.0
-gradio>=4.38.0,<=5.18.0
+gradio>=4.38.0,<=5.21.0
 pandas>=2.0.0
 scipy
 einops
scripts/vllm_infer.py
@@ -38,7 +38,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: int = None,
+    max_samples: Optional[int] = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
@@ -46,6 +46,7 @@ def vllm_infer(
     top_k: int = 50,
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
+    skip_special_tokens: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
@@ -97,19 +98,21 @@ def vllm_infer(
             multi_modal_data = None

         inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
         labels.append(
-            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+            tokenizer.decode(
+                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
+            )
         )

     sampling_params = SamplingParams(
         repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
         top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k,
+        top_k=generating_args.top_k or -1,  # top_k must > 0
         stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
         max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=False,
+        skip_special_tokens=skip_special_tokens,
         seed=seed,
     )

     if model_args.adapter_name_or_path is not None:
@@ -121,6 +124,7 @@ def vllm_infer(
         "model": model_args.model_name_or_path,
         "trust_remote_code": True,
         "dtype": model_args.infer_dtype,
         "max_model_len": cutoff_len + max_new_tokens,
         "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
         "pipeline_parallel_size": pipeline_parallel_size,
         "disable_log_stats": True,
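Note on the `or -1` fallback used for `top_k` above: Python's `or` substitutes the right-hand value whenever the left side is falsy (0 or None), which is how a disabled top-k setting is mapped to the sentinel value SamplingParams accepts. A minimal sketch, assuming only standard Python semantics:

```python
# Minimal sketch of the fallback pattern (not part of the commit itself).
def resolve_top_k(configured_top_k):
    # 0 and None are falsy, so both fall back to -1 (the "no top-k filtering" value).
    return configured_top_k or -1

assert resolve_top_k(50) == 50
assert resolve_top_k(0) == -1
assert resolve_top_k(None) == -1
```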
setup.py
@@ -46,7 +46,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
src/llamafactory/api/app.py
@@ -21,6 +21,7 @@ from typing import Optional
 from typing_extensions import Annotated

 from ..chat import ChatModel
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
 from .chat import (
@@ -60,7 +61,7 @@ async def sweeper() -> None:
 @asynccontextmanager
 async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
-    if chat_model.engine_type == "huggingface":
+    if chat_model.engine.name == EngineName.HF:
         asyncio.create_task(sweeper())

     yield
@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if request.stream:
             generate = create_stream_chat_completion_response(request, chat_model)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
         else:
             return await create_chat_completion_response(request, chat_model)
src/llamafactory/chat/base_engine.py
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from ..data import Template
     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..extras.constants import EngineName
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -41,6 +42,7 @@ class BaseEngine(ABC):
     Must implements async methods: chat(), stream_chat() and get_scores().
     """

+    name: "EngineName"
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
src/llamafactory/chat/chat_model.py
@@ -20,6 +20,7 @@ import os
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
@@ -47,10 +48,9 @@ class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        self.engine_type = model_args.infer_backend
-        if model_args.infer_backend == "huggingface":
+        if model_args.infer_backend == EngineName.HF:
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
-        elif model_args.infer_backend == "vllm":
+        elif model_args.infer_backend == EngineName.VLLM:
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
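A hedged usage sketch of the backend selection above; the model path and template value are placeholders rather than values from this commit, and the string comparison works because EngineName subclasses str:

```python
# Illustrative only: running this requires an actual model on disk or on the Hub.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "path/to/model",  # placeholder
    "infer_backend": "vllm",                 # equals EngineName.VLLM since the enum subclasses str
    "template": "default",                   # assumed key, not shown in this diff
})
print(chat_model.engine.name)                # -> EngineName.VLLM
```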
src/llamafactory/chat/hf_engine.py
@@ -24,7 +24,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.HF
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
         self.tokenizer = tokenizer_module["tokenizer"]
src/llamafactory/chat/vllm_engine.py
@@ -19,7 +19,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -49,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.VLLM
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
@@ -169,7 +170,7 @@ class VllmEngine(BaseEngine):
             or 1.0,  # repetition_penalty must > 0
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
-            top_k=top_k if top_k is not None else self.generating_args["top_k"],
+            top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
             stop=stop,
             stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
             max_tokens=max_tokens,
src/llamafactory/cli.py
@@ -88,18 +88,24 @@ def main():
     elif command == Command.TRAIN:
         force_torchrun = is_env_enabled("FORCE_TORCHRUN")
         if force_torchrun or (get_device_count() > 1 and not use_ray()):
+            nnodes = os.getenv("NNODES", "1")
+            node_rank = os.getenv("NODE_RANK", "0")
+            nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
+            logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+            if int(nnodes) > 1:
+                print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
                 ).format(
-                    nnodes=os.getenv("NNODES", "1"),
-                    node_rank=os.getenv("NODE_RANK", "0"),
-                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
+                    nnodes=nnodes,
+                    node_rank=node_rank,
+                    nproc_per_node=nproc_per_node,
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
@@ -119,7 +125,7 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError(f"Unknown command: {command}.")
+        print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
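For reference, the command assembled by the template string above, formatted with the same defaults the code reads from the environment; the launcher file and trailing arguments are placeholders:

```python
# Sketch reproducing the torchrun command string built in main() (values illustrative).
command = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
    nnodes="1",                  # NNODES
    node_rank="0",               # NODE_RANK
    nproc_per_node="8",          # NPROC_PER_NODE, defaults to the detected device count
    master_addr="127.0.0.1",     # MASTER_ADDR
    master_port="23456",         # MASTER_PORT, defaults to a random port in [20001, 29999]
    file_name="launcher.py",     # placeholder for launcher.__file__
    args="examples/train.yaml",  # placeholder for the forwarded CLI arguments
)
print(command)
```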
src/llamafactory/data/data_utils.py
@@ -43,7 +43,7 @@ class Role(str, Enum):
 class DatasetModule(TypedDict):
     train_dataset: Optional[Union["Dataset", "IterableDataset"]]
-    eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]


 def merge_dataset(
@@ -54,11 +54,13 @@ def merge_dataset(
     """
     if len(all_datasets) == 1:
         return all_datasets[0]

     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
             logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")

         return concatenate_datasets(all_datasets)

     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
@@ -69,24 +71,75 @@ def merge_dataset(
             seed=seed,
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
         raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")


 def split_dataset(
-    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
+    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
+    data_args: "DataArguments",
+    seed: int,
 ) -> "DatasetDict":
     r"""
     Splits the dataset and returns a dataset dict containing train set and validation set.

     Supports both map dataset and iterable dataset.
     """
-    if data_args.streaming:
-        dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
-        val_set = dataset.take(int(data_args.val_size))
-        train_set = dataset.skip(int(data_args.val_size))
-        return DatasetDict({"train": train_set, "validation": val_set})
-    else:
-        val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-        dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-        return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
+    if eval_dataset is not None and data_args.val_size > 1e-6:
+        raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
+
+    dataset_dict = {}
+    if dataset is not None:
+        if data_args.streaming:
+            dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+
+        if data_args.val_size > 1e-6:
+            if data_args.streaming:
+                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
+                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+            else:
+                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
+                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+        else:
+            dataset_dict["train"] = dataset
+
+    if eval_dataset is not None:
+        if isinstance(eval_dataset, dict):
+            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+        else:
+            if data_args.streaming:
+                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+
+            dataset_dict["validation"] = eval_dataset
+
+    return DatasetDict(dataset_dict)
+
+
+def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
+    r"""
+    Converts dataset or dataset dict to dataset module.
+    """
+    dataset_module: "DatasetModule" = {}
+    if isinstance(dataset, DatasetDict):  # dataset dict
+        if "train" in dataset:
+            dataset_module["train_dataset"] = dataset["train"]
+
+        if "validation" in dataset:
+            dataset_module["eval_dataset"] = dataset["validation"]
+        else:
+            eval_dataset = {}
+            for key in dataset.keys():
+                if key.startswith("validation_"):
+                    eval_dataset[key[len("validation_"):]] = dataset[key]
+
+            if len(eval_dataset):
+                dataset_module["eval_dataset"] = eval_dataset
+    else:  # single dataset
+        dataset_module["train_dataset"] = dataset
+
+    return dataset_module
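A toy sketch of how the new `split_dataset` / `get_dataset_module` pair composes; the datasets and argument values are made up for illustration and assume DataArguments' defaults are sufficient:

```python
from datasets import Dataset

from llamafactory.data.data_utils import get_dataset_module, split_dataset
from llamafactory.hparams import DataArguments

train = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
extra_eval = {"alpaca": Dataset.from_dict({"text": ["e"]})}

data_args = DataArguments(val_size=0.0)  # val_size must stay 0 when eval_dataset is given
dataset_dict = split_dataset(train, extra_eval, data_args, seed=42)
# -> DatasetDict with keys "train" and "validation_alpaca"

module = get_dataset_module(dataset_dict)
# -> {"train_dataset": ..., "eval_dataset": {"alpaca": ...}}
```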
src/llamafactory/data/formatter.py
@@ -121,7 +121,7 @@ class FunctionFormatter(StringFormatter):
         function_str = self.tool_utils.function_formatter(functions)
         if thought:
-            function_str = thought.group(1) + function_str
+            function_str = thought.group(0) + function_str

         return super().apply(content=function_str)
src/llamafactory/data/loader.py
@@ -13,17 +13,16 @@
 # limitations under the License.

 import os
-import sys
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union

 import numpy as np
-from datasets import DatasetDict, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
 from .converter import align_dataset
-from .data_utils import merge_dataset, split_dataset
+from .data_utils import get_dataset_module, merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .processor import (
     FeedbackDatasetProcessor,
@@ -292,23 +291,12 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
-            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
-
-            dataset_module: Dict[str, "Dataset"] = {}
-            if isinstance(tokenized_data, DatasetDict):
-                if "train" in tokenized_data:
-                    dataset_module["train_dataset"] = tokenized_data["train"]
-
-                if "validation" in tokenized_data:
-                    dataset_module["eval_dataset"] = tokenized_data["validation"]
-            else:  # single dataset
-                dataset_module["train_dataset"] = tokenized_data
-
+            tokenized_data = load_from_disk(data_args.tokenized_path)
+            dataset_module = get_dataset_module(tokenized_data)
             if data_args.streaming:
-                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+                dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()

+            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
             return dataset_module

     if data_args.streaming:
@@ -335,48 +323,11 @@ def get_dataset(
             eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
         )

-        if data_args.val_size > 1e-6:
-            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
-        else:
-            dataset_dict = {}
-            if dataset is not None:
-                if data_args.streaming:
-                    dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                dataset_dict["train"] = dataset
-
-            if eval_dataset is not None:
-                if isinstance(eval_dataset, dict):
-                    dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
-                else:
-                    if data_args.streaming:
-                        eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                    dataset_dict["validation"] = eval_dataset
-
-            dataset_dict = DatasetDict(dataset_dict)
-
-        if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
-            if training_args.should_save:
-                dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
-                logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
-
-            sys.exit(0)
-
-        dataset_module = {}
-        if "train" in dataset_dict:
-            dataset_module["train_dataset"] = dataset_dict["train"]
-
-        if "validation" in dataset_dict:
-            dataset_module["eval_dataset"] = dataset_dict["validation"]
-        else:
-            eval_dataset = {}
-            for key in dataset_dict.keys():
-                if key.startswith("validation_"):
-                    eval_dataset[key[len("validation_"):]] = dataset_dict[key]
-
-            if len(eval_dataset):
-                dataset_module["eval_dataset"] = eval_dataset
-
-        return dataset_module
+        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        if data_args.tokenized_path is not None:  # save tokenized dataset to disk
+            if training_args.should_save:
+                dataset_dict.save_to_disk(data_args.tokenized_path)
+
+            logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
+            logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")
+
+        return get_dataset_module(dataset_dict)
src/llamafactory/extras/constants.py
@@ -96,12 +96,31 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"


+class AttentionFunction(str, Enum):
+    AUTO = "auto"
+    DISABLED = "disabled"
+    SDPA = "sdpa"
+    FA2 = "fa2"
+
+
+class EngineName(str, Enum):
+    HF = "huggingface"
+    VLLM = "vllm"
+
+
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
     OPENMIND = "om"


+class RopeScaling(str, Enum):
+    LINEAR = "linear"
+    DYNAMIC = "dynamic"
+    YARN = "yarn"
+    LLAMA3 = "llama3"
+
+
 def register_model_group(
     models: Dict[str, Dict[DownloadSource, str]],
     template: Optional[str] = None,
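Because these new enums subclass str, their members compare equal to the plain strings that existing configs pass in; a quick standard-library illustration:

```python
from enum import Enum

class EngineName(str, Enum):
    HF = "huggingface"
    VLLM = "vllm"

assert EngineName.HF == "huggingface"         # string comparison still holds
assert EngineName("vllm") is EngineName.VLLM  # parsing a raw config value
```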
src/llamafactory/extras/env.py
@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.9.2.dev0"
+VERSION = "0.9.2"


 def print_env() -> None:
@@ -74,4 +74,13 @@ def print_env() -> None:
     except Exception:
         pass

+    try:
+        import subprocess
+
+        commit_info = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True)
+        commit_hash = commit_info.stdout.strip()
+        info["Git commit"] = commit_hash
+    except Exception:
+        pass
+
     print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
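The added block degrades gracefully when git is missing or the package is not inside a repository; a standalone sketch of the same lookup:

```python
import subprocess

try:
    commit_hash = subprocess.run(
        ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True
    ).stdout.strip()
except Exception:  # e.g. git not installed, or not a git checkout
    commit_hash = "unknown"

print(f"- Git commit: {commit_hash}")
```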
src/llamafactory/hparams/finetuning_args.py
@@ -363,15 +363,15 @@ class SwanLabArguments:
         default=False,
         metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
     )
-    swanlab_project: str = field(
+    swanlab_project: Optional[str] = field(
         default="llamafactory",
         metadata={"help": "The project name in SwanLab."},
     )
-    swanlab_workspace: str = field(
+    swanlab_workspace: Optional[str] = field(
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
     )
-    swanlab_run_name: str = field(
+    swanlab_run_name: Optional[str] = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
@@ -379,15 +379,19 @@ class SwanLabArguments:
         default="cloud",
         metadata={"help": "The mode of SwanLab."},
     )
-    swanlab_api_key: str = field(
+    swanlab_api_key: Optional[str] = field(
         default=None,
         metadata={"help": "The API key for SwanLab."},
     )
+    swanlab_logdir: Optional[str] = field(
+        default=None,
+        metadata={"help": "The log directory for SwanLab."},
+    )


 @dataclass
 class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
+    SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
 ):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
@@ -415,15 +419,15 @@ class FinetuningArguments(
     )
     freeze_vision_tower: bool = field(
         default=True,
-        metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
+        metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
     )
     freeze_multi_modal_projector: bool = field(
         default=True,
         metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
     )
-    train_mm_proj_only: bool = field(
+    freeze_language_model: bool = field(
         default=False,
-        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+        metadata={"help": "Whether or not to freeze the language model in MLLM training."},
     )
     compute_accuracy: bool = field(
         default=False,
@@ -455,8 +459,6 @@ class FinetuningArguments(
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.apollo_target: List[str] = split_arg(self.apollo_target)
-        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -484,9 +486,6 @@ class FinetuningArguments(
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")

-        if self.train_mm_proj_only and self.finetuning_type != "full":
-            raise ValueError("`train_mm_proj_only` is only valid for full training.")
-
         if self.finetuning_type != "lora":
             if self.loraplus_lr_ratio is not None:
                 raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
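Hedged migration note: judging from the deleted `__post_init__` lines above and the visual.py change further below, a configuration that previously used the removed `train_mm_proj_only` flag would map roughly to the new flags as follows (illustrative, not an official mapping from the commit message):

```python
# Rough equivalent of the old train_mm_proj_only=True under the new flags (assumption).
train_mm_proj_only_equivalent = {
    "finetuning_type": "full",              # the old flag was only valid for full training
    "freeze_vision_tower": True,
    "freeze_multi_modal_projector": False,  # the projector stays trainable
    "freeze_language_model": True,          # new flag introduced in this commit
}
```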
src/llamafactory/hparams/model_args.py
@@ -23,6 +23,164 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self

+from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+
+
+@dataclass
+class BaseModelArguments:
+    r"""
+    Arguments pertaining to the model.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."},
+    )
+    adapter_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to the adapter weight or identifier from huggingface.co/models. "
+                "Use commas to separate multiple adapters."
+            )
+        },
+    )
+    adapter_folder: Optional[str] = field(
+        default=None,
+        metadata={"help": "The folder containing the adapter weights to load."},
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
+    )
+    resize_vocab: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
+    )
+    split_special_tokens: bool = field(
+        default=False,
+        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
+    )
+    new_special_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
+    rope_scaling: Optional[RopeScaling] = field(
+        default=None,
+        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
+    )
+    flash_attn: AttentionFunction = field(
+        default=AttentionFunction.AUTO,
+        metadata={"help": "Enable FlashAttention for faster training and inference."},
+    )
+    shift_attn: bool = field(
+        default=False,
+        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
+    )
+    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+        default=None,
+        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
+    )
+    use_unsloth: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
+    )
+    use_unsloth_gc: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's gradient checkpointing (no need to install unsloth)."},
+    )
+    enable_liger_kernel: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable liger kernel for faster training."},
+    )
+    moe_aux_loss_coef: Optional[float] = field(
+        default=None,
+        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
+    )
+    disable_gradient_checkpointing: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable gradient checkpointing."},
+    )
+    use_reentrant_gc: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
+    )
+    upcast_layernorm: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
+    )
+    upcast_lmhead_output: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
+    )
+    train_from_scratch: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to randomly initialize the model weights."},
+    )
+    infer_backend: EngineName = field(
+        default=EngineName.HF,
+        metadata={"help": "Backend engine used at inference."},
+    )
+    offload_folder: str = field(
+        default="offload",
+        metadata={"help": "Path to offload model weights."},
+    )
+    use_cache: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use KV cache in generation."},
+    )
+    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+        default="auto",
+        metadata={"help": "Data type for model weights and activations at inference."},
+    )
+    hf_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Hugging Face Hub."},
+    )
+    ms_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with ModelScope Hub."},
+    )
+    om_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Modelers Hub."},
+    )
+    print_param_status: bool = field(
+        default=False,
+        metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
+    )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
+    )
+
+    def __post_init__(self):
+        if self.model_name_or_path is None:
+            raise ValueError("Please provide `model_name_or_path`.")
+
+        if self.split_special_tokens and self.use_fast_tokenizer:
+            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
+        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
+            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
+
+        if self.new_special_tokens is not None:  # support multiple special tokens
+            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+

 @dataclass
 class QuantizationArguments:
@@ -127,6 +285,10 @@ class ExportArguments:
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )

+    def __post_init__(self):
+        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+            raise ValueError("Quantization dataset is necessary for exporting.")
+

 @dataclass
 class VllmArguments:
@@ -155,148 +317,19 @@ class VllmArguments:
         metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
     )

+    def __post_init__(self):
+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
+

 @dataclass
-class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
+class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
     r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+    The class on the most right will be displayed first.
     """

-    (removed here: the base-model field declarations from model_name_or_path through trust_remote_code, now defined once
-     on BaseModelArguments above; the old copies typed rope_scaling, flash_attn and infer_backend as plain Literal strings
-     rather than the new RopeScaling, AttentionFunction and EngineName enums)
     compute_dtype: Optional[torch.dtype] = field(
         default=None,
         init=False,
@@ -319,23 +352,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
     )

     def __post_init__(self):
-        if self.model_name_or_path is None:
-            raise ValueError("Please provide `model_name_or_path`.")
-
-        if self.split_special_tokens and self.use_fast_tokenizer:
-            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
-
-        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
-            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
-
-        if self.new_special_tokens is not None:  # support multiple special tokens
-            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
-
-        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
-            raise ValueError("Quantization dataset is necessary for exporting.")
-
-        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
-            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
+        BaseModelArguments.__post_init__(self)
+        ExportArguments.__post_init__(self)
+        VllmArguments.__post_init__(self)

     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
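The added docstring line "The class on the most right will be displayed first." refers to how dataclass inheritance orders fields: bases are processed in reverse MRO, so the right-most base contributes its fields first. A quick standard-library demonstration:

```python
from dataclasses import dataclass, fields

@dataclass
class Base:
    base_field: int = 0

@dataclass
class Extra:
    extra_field: int = 0

@dataclass
class Combined(Extra, Base):  # right-most base's fields come first
    own_field: int = 0

print([f.name for f in fields(Combined)])  # ['base_field', 'extra_field', 'own_field']
```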
src/llamafactory/hparams/parser.py
@@ -382,10 +382,10 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     # Log on each process the small summary
     logger.info(
-        "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
-            training_args.local_rank,
+        "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
+            training_args.process_index,
+            training_args.world_size,
             training_args.device,
-            training_args.n_gpu,
             training_args.parallel_mode == ParallelMode.DISTRIBUTED,
             str(model_args.compute_dtype),
         )
@@ -418,7 +418,8 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
     if model_args.export_dir is not None and model_args.export_device == "cpu":
         model_args.device_map = {"": torch.device("cpu")}
-        model_args.model_max_length = data_args.cutoff_len
+        if data_args.cutoff_len != DataArguments().cutoff_len:  # override cutoff_len if it is not default
+            model_args.model_max_length = data_args.cutoff_len
     else:
         model_args.device_map = "auto"
src/llamafactory/model/model_utils/attention.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

 from ...extras import logging
+from ...extras.constants import AttentionFunction
 from ...extras.misc import check_version
@@ -33,34 +34,34 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
-        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 check_version("transformers>=4.42.4")
                 check_version("flash_attn>=2.6.3")
-                if model_args.flash_attn != "fa2":
-                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-                    model_args.flash_attn = "fa2"
+                if model_args.flash_attn != AttentionFunction.FA2:
+                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = AttentionFunction.FA2
             else:
                 logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
-                model_args.flash_attn = "disabled"
-        elif model_args.flash_attn == "sdpa":
+                model_args.flash_attn = AttentionFunction.DISABLED
+        elif model_args.flash_attn == AttentionFunction.SDPA:
             logger.warning_rank0("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")

-    if model_args.flash_attn == "auto":
+    if model_args.flash_attn == AttentionFunction.AUTO:
         return

-    elif model_args.flash_attn == "disabled":
+    elif model_args.flash_attn == AttentionFunction.DISABLED:
         requested_attn_implementation = "eager"

-    elif model_args.flash_attn == "sdpa":
+    elif model_args.flash_attn == AttentionFunction.SDPA:
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return

         requested_attn_implementation = "sdpa"

-    elif model_args.flash_attn == "fa2":
+    elif model_args.flash_attn == AttentionFunction.FA2:
         if not is_flash_attn_2_available():
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
src/llamafactory/model/model_utils/rope.py
@@ -20,6 +20,7 @@ import math
 from typing import TYPE_CHECKING

 from ...extras import logging
+from ...extras.constants import RopeScaling


 if TYPE_CHECKING:
@@ -39,33 +40,32 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return

-    rope_kwargs = {}
+    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
     if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == "dynamic":
+        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )

         current_max_length = getattr(config, "max_position_embeddings", None)
-        if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-            setattr(config, "max_position_embeddings", model_args.model_max_length)
-            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        else:
-            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            rope_kwargs["factor"] = 1.0
+        if (not current_max_length) or model_args.model_max_length <= current_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
+        setattr(config, "max_position_embeddings", model_args.model_max_length)
+        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))

-        if model_args.rope_scaling == "dynamic":
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == "llama3":
+        elif model_args.rope_scaling == RopeScaling.LLAMA3:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
             rope_kwargs["low_freq_factor"] = 1.0
             rope_kwargs["high_freq_factor"] = 4.0
     else:
         rope_kwargs["factor"] = 2.0

-    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
+    setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
+        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
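A worked example of the scaling factor computed above, with illustrative numbers:

```python
import math

current_max_length = 4096   # config.max_position_embeddings
model_max_length = 10000    # requested model_max_length
factor = float(math.ceil(model_max_length / current_max_length))
print(factor)  # 3.0

# With rope_scaling="linear", the resulting config value would look like:
# config.rope_scaling == {"rope_type": "linear", "factor": 3.0}
```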
src/llamafactory/model/model_utils/visual.py
@@ -166,7 +166,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
             logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
             forbidden_modules.add(projector_key)

-        if finetuning_args.train_mm_proj_only:
+        if finetuning_args.freeze_language_model:
             language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
             logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
             forbidden_modules.update(language_model_keys)