OpenDAS / LLaMA-Factory / Commits

Commit 4a40151b
Update v0.8.3
Authored Nov 05, 2024 by chenych
Parent: 731cf9b8
Changes: 56 files in this commit. This page shows 20 changed files with 72 additions and 190 deletions (+72 -190); the remaining files appear on the following pages.
Changed files on this page:

scripts/llamafy_baichuan2.py  (+2 -5)
scripts/llamafy_qwen.py  (+2 -5)
scripts/loftq_init.py  (+1 -1)
scripts/pissa_init.py  (+1 -2)
setup.py  (+0 -1)
src/llamafactory/__init__.py  (+9 -7)
src/llamafactory/data/aligner.py  (+3 -3)
src/llamafactory/data/processors/feedback.py  (+4 -4)
src/llamafactory/data/processors/pairwise.py  (+5 -4)
src/llamafactory/data/processors/supervised.py  (+11 -24)
src/llamafactory/data/processors/unsupervised.py  (+3 -3)
src/llamafactory/data/template.py  (+6 -22)
src/llamafactory/extras/constants.py  (+0 -66)
src/llamafactory/extras/env.py  (+1 -1)
src/llamafactory/extras/misc.py  (+7 -9)
src/llamafactory/extras/packages.py  (+0 -5)
src/llamafactory/hparams/data_args.py  (+0 -3)
src/llamafactory/hparams/finetuning_args.py  (+0 -4)
src/llamafactory/hparams/parser.py  (+16 -20)
src/llamafactory/model/model_utils/attention.py  (+1 -1)
scripts/llamafy_baichuan2.py

@@ -16,7 +16,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import fire
 import torch

@@ -86,10 +86,7 @@ def save_config(input_dir: str, output_dir: str):
 def llamafy_baichuan2(
-    input_dir: str,
-    output_dir: str,
-    shard_size: str = "2GB",
-    save_safetensors: bool = True,
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
 ):
     r"""
     Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
scripts/llamafy_qwen.py

@@ -16,7 +16,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import fire
 import torch

@@ -139,10 +139,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
 def llamafy_qwen(
-    input_dir: str,
-    output_dir: str,
-    shard_size: str = "2GB",
-    save_safetensors: bool = False,
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
 ):
     r"""
     Converts the Qwen models in the same format as LLaMA2.
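Note on the two conversion scripts above: the revert restores the v0.8.3 signatures, with Optional-typed parameters and save_safetensors defaulting to False, so .bin shards are written unless safetensors output is requested explicitly. A minimal sketch of how such defaults surface on the command line through the usual fire entry point (the function body below is a placeholder for illustration, not the real converter):

import fire


def llamafy(input_dir: str, output_dir: str, shard_size: str = "2GB", save_safetensors: bool = False):
    # placeholder body; the real scripts rewrite checkpoint weights and config files
    print(f"convert {input_dir} -> {output_dir}, shard_size={shard_size}, safetensors={save_safetensors}")


if __name__ == "__main__":
    fire.Fire(llamafy)  # e.g. python llamafy.py --input_dir in --output_dir out --save_safetensors True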
scripts/loftq_init.py

@@ -67,7 +67,7 @@ def quantize_loftq(
     loftq_dir = os.path.join(output_dir, "loftq_init")
     # Save LoftQ model
-    setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
+    setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
     peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
     print("Adapter weights saved in {}".format(loftq_dir))
scripts/pissa_init.py

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
 def quantize_pissa(
     model_name_or_path: str,
     output_dir: str,
-    pissa_iter: int = 16,
+    pissa_iter: int = 4,
     lora_alpha: int = None,
     lora_rank: int = 16,
     lora_dropout: float = 0,

@@ -62,7 +62,6 @@ def quantize_pissa(
     pissa_dir = os.path.join(output_dir, "pissa_init")
     # Save PiSSA model
-    setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
     peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
     print("Adapter weights saved in {}".format(pissa_dir))
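The loftq_init.py and pissa_init.py hunks both touch the base_model_name_or_path recorded in the saved adapter config: the reverted code wrote an absolute path via os.path.abspath, while the restored v0.8.3 code stores output_dir as given in loftq_init.py and omits the assignment entirely in pissa_init.py. A small sketch of the practical difference, using a made-up output directory:

import os

output_dir = "models/pissa-output"    # hypothetical path, for illustration only
print(output_dir)                     # what the restored loftq script records (relative, as passed in)
print(os.path.abspath(output_dir))    # what the reverted scripts recorded (tied to the current working dir)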
setup.py

@@ -47,7 +47,6 @@ extra_require = {
     "vllm": ["vllm>=0.4.3"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
-    "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
     "dev": ["ruff", "pytest"],
src/llamafactory/__init__.py

@@ -20,17 +20,19 @@ Level:
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.43.4
-    datasets>=2.16.0,<=2.20.0
-    accelerate>=0.30.1,<=0.32.0
-    peft>=0.11.1,<=0.12.0
-    trl>=0.8.6,<=0.9.6
+    transformers>=4.41.2
+    datasets>=2.16.0
+    accelerate>=0.30.1
+    peft>=0.11.1
+    trl>=0.8.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.43.4
+    transformers>=4.41.2,<=4.42.4
   packing:
-    transformers>=4.41.2,<=4.43.4
+    transformers>=4.41.2,<=4.42.4
+  patcher:
+    transformers==4.41.2 (chatglm)
 """

 from .cli import VERSION
src/llamafactory/data/aligner.py

@@ -120,15 +120,15 @@ def convert_sharegpt(
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
     accept_tags = (odd_tags, even_tags)
     for i, messages in enumerate(examples[dataset_attr.messages]):
-        if len(messages) == 0:
-            continue
-
         if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
             system = messages[0][dataset_attr.content_tag]
             messages = messages[1:]
         else:
             system = examples[dataset_attr.system][i] if dataset_attr.system else ""

+        if len(messages) == 0:
+            continue
+
         aligned_messages = []
         broken_data = False
         for turn_idx, message in enumerate(messages):
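The aligner.py change only moves the empty-conversation guard, but the position matters: checking after the system turn has been stripped also skips examples that contain nothing but a system message. A toy sketch with made-up sharegpt-style data (tag names are assumptions, mirroring the role/content tags in aligner.py):

role_tag, content_tag = "from", "value"                        # hypothetical tag names
messages = [{"from": "system", "value": "You are helpful."}]   # system turn only

system = ""
if messages and messages[0][role_tag] == "system":
    system = messages[0][content_tag]
    messages = messages[1:]            # nothing left after stripping the system turn

if len(messages) == 0:                 # the restored code checks here, so this degenerate example is skipped
    print("skipped, system prompt was:", system)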
src/llamafactory/data/processors/feedback.py

@@ -38,7 +38,7 @@ def _encode_feedback_example(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    cutoff_len: int,
+    data_args: "DataArguments",
 ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
         prompt[0]["content"] = template.image_token + prompt[0]["content"]

@@ -67,10 +67,10 @@ def _encode_feedback_example(
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
         kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids

-    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
+    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
     prompt_ids = prompt_ids[:source_len]
     response_ids = response_ids[:target_len]
-    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
+    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
     kl_prompt_ids = kl_prompt_ids[:kl_source_len]
     kl_response_ids = kl_response_ids[:kl_target_len]

@@ -120,7 +120,7 @@ def preprocess_feedback_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len,
+            data_args=data_args,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
src/llamafactory/data/processors/pairwise.py

@@ -37,7 +37,7 @@ def _encode_pairwise_example(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    cutoff_len: int,
+    data_args: "DataArguments",
 ) -> Tuple[List[int], List[int], List[int], List[int]]:
     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
         prompt[0]["content"] = template.image_token + prompt[0]["content"]

@@ -55,8 +55,9 @@ def _encode_pairwise_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

-    source_len, target_len = infer_seqlen(
-        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len
-    )  # consider the response is more important
+    # consider the response is more important
+    source_len, target_len = infer_seqlen(
+        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    )
     prompt_ids = prompt_ids[:source_len]
     chosen_ids = chosen_ids[:target_len]
     rejected_ids = rejected_ids[:target_len]

@@ -104,7 +105,7 @@ def preprocess_pairwise_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len,
+            data_args=data_args,
         )
         model_inputs["chosen_input_ids"].append(chosen_input_ids)
         model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
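Across the processor files the cutoff now comes from data_args.cutoff_len, and the pairwise encoder budgets for the longer of the two candidate responses before truncating. infer_seqlen itself is unchanged by this commit; the snippet below is a rough stand-in for how a proportional split of the cutoff behaves (simplified, not the actual implementation):

def split_budget(source_len: int, target_len: int, cutoff_len: int) -> tuple:
    # simplified proportional split, standing in for infer_seqlen
    if source_len + target_len <= cutoff_len:
        return source_len, target_len
    new_source = int(cutoff_len * source_len / (source_len + target_len))
    return new_source, cutoff_len - new_source


prompt_len, chosen_len, rejected_len = 900, 500, 300
print(split_budget(prompt_len, max(chosen_len, rejected_len), 1024))  # (658, 366)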
src/llamafactory/data/processors/supervised.py

@@ -38,9 +38,7 @@ def _encode_supervised_example(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    cutoff_len: int,
-    train_on_prompt: bool,
-    mask_history: bool,
+    data_args: "DataArguments",
 ) -> Tuple[List[int], List[int]]:
     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
         prompt[0]["content"] = template.image_token + prompt[0]["content"]

@@ -55,34 +53,27 @@ def _encode_supervised_example(
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = 1 if template.efficient_eos else 0
-    if mask_history:
-        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
-
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
+        if total_length >= data_args.cutoff_len:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
         source_ids = source_ids[:source_len]
         target_ids = target_ids[:target_len]
         total_length += source_len + target_len

-        if train_on_prompt:
+        if data_args.train_on_prompt:
             source_label = source_ids
-        elif template.efficient_eos:
+        elif turn_idx != 0 and template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len

-        if mask_history and turn_idx != 0:  # train on the last turn only
+        if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
             target_label = [IGNORE_INDEX] * target_len
         else:
             target_label = target_ids

-        if mask_history:  # reversed sequences
-            input_ids = source_ids + target_ids + input_ids
-            labels = source_label + target_label + labels
-        else:
-            input_ids += source_ids + target_ids
-            labels += source_label + target_label
+        input_ids += source_ids + target_ids
+        labels += source_label + target_label

@@ -121,9 +112,7 @@ def preprocess_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len,
-            train_on_prompt=data_args.train_on_prompt,
-            mask_history=data_args.mask_history,
+            data_args=data_args,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))

@@ -161,9 +150,7 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=None,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
-            train_on_prompt=data_args.train_on_prompt,
-            mask_history=data_args.mask_history,
+            data_args=data_args,
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:

@@ -176,7 +163,7 @@ def preprocess_packed_supervised_dataset(
             valid_num += 1

     model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         for i, length in enumerate(knapsack):
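The largest behavioural difference in this file is mask_history: the restored code walks the turns in order and masks every target except the last one (turn_idx != len(encoded_pairs) - 1), whereas the reverted code iterated the reversed pairs and masked turn_idx != 0. A toy sketch of the restored labelling, with made-up token ids:

IGNORE_INDEX = -100
mask_history = True
encoded_pairs = [([1, 2], [3, 4]), ([5, 6], [7, 8, 9])]  # (source_ids, target_ids) per turn

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
    source_label = [IGNORE_INDEX] * len(source_ids)
    if mask_history and turn_idx != len(encoded_pairs) - 1:
        target_label = [IGNORE_INDEX] * len(target_ids)   # earlier answers are masked out
    else:
        target_label = target_ids                         # only the last answer contributes to the loss
    input_ids += source_ids + target_ids
    labels += source_label + target_label

print(labels)  # [-100, -100, -100, -100, -100, -100, 7, 8, 9]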
src/llamafactory/data/processors/unsupervised.py

@@ -37,7 +37,7 @@ def _encode_unsupervised_example(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
-    cutoff_len: int,
+    data_args: "DataArguments",
 ) -> Tuple[List[int], List[int]]:
     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
         prompt[0]["content"] = template.image_token + prompt[0]["content"]

@@ -55,7 +55,7 @@ def _encode_unsupervised_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids

-    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
     return input_ids, labels

@@ -88,7 +88,7 @@ def preprocess_unsupervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len,
+            data_args=data_args,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
src/llamafactory/data/template.py

@@ -310,15 +310,14 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
         jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

     jinja_template += (
-        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
-        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+        "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
     )

     system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
     if not isinstance(template, Llama2Template):
         jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

-    jinja_template += "{% for message in loop_messages %}"
+    jinja_template += "{% for message in messages %}"
     jinja_template += "{% set content = message['content'] %}"
     if isinstance(template, Llama2Template):
         jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"

@@ -579,7 +578,6 @@ _register_template(
 _register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
-    format_system=StringFormatter(slots=["{{content}}\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )

@@ -587,14 +585,14 @@ _register_template(
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
-        "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
-        "developed by DeepSeek Company, and you only answer questions related to computer science. "
+        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
+        "developed by Deepseek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
-        "and other non-computer science questions, you will refuse to answer.\n"
+        "and other non-computer science questions, you will refuse to answer\n"
     ),
 )

@@ -783,20 +781,6 @@ _register_template(
 )


-_register_template(
-    name="sailor",
-    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
-    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
-    default_system=(
-        "You are an AI assistant named Sailor created by Sea AI Lab. "
-        "Your answer should be friendly, unbiased, faithful, informative and detailed."
-    ),
-    stop_words=["<|im_end|>"],
-    replace_eos=True,
-)
-
-
 _register_template(
     name="solar",
     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
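The _get_jinja_template change swaps how the exported chat template treats a leading system turn: the restored fragment only captures it into system_message, while the reverted fragment also removed it from the list rendered by the loop. A minimal sketch with jinja2 installed (fragments simplified from the strings in the diff; the real builder also applies the per-template formatters):

from jinja2 import Template

reverted = Template(
    "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
    "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
    "{{ system_message | default('') }} | {% for m in loop_messages %}{{ m['role'] }} {% endfor %}"
)
restored = Template(
    "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
    "{{ system_message | default('') }} | {% for m in messages %}{{ m['role'] }} {% endfor %}"
)
msgs = [{"role": "system", "content": "be brief"}, {"role": "user", "content": "hi"}]
print(reverted.render(messages=msgs))  # be brief | user
print(restored.render(messages=msgs))  # be brief | system user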
src/llamafactory/extras/constants.py

@@ -531,10 +531,6 @@ register_model_group(
         "Gemma-1.1-7B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
         },
-        "Gemma-2-2B": {
-            DownloadSource.DEFAULT: "google/gemma-2-2b",
-            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
-        },
         "Gemma-2-9B": {
             DownloadSource.DEFAULT: "google/gemma-2-9b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",

@@ -543,10 +539,6 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-2-27b",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
         },
-        "Gemma-2-2B-Chat": {
-            DownloadSource.DEFAULT: "google/gemma-2-2b-it",
-            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
-        },
         "Gemma-2-9B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",

@@ -747,35 +739,6 @@ register_model_group(
 )


-register_model_group(
-    models={
-        "LLaMA3.1-8B": {
-            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
-            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
-        },
-        "LLaMA3.1-70B": {
-            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
-            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
-        },
-        "LLaMA3.1-405B": {
-            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
-        },
-        "LLaMA3.1-8B-Chat": {
-            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
-            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
-        },
-        "LLaMA3.1-70B-Chat": {
-            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
-            DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
-        },
-        "LLaMA3.1-405B-Chat": {
-            DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
-        },
-    },
-    template="llama3",
-)
-
-
 register_model_group(
     models={
         "LLaVA1.5-7B-Chat": {

@@ -828,11 +791,6 @@ register_model_group(
         },
         "Mistral-7B-v0.3-Chat": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
-            DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
-        },
-        "Mistral-Nemo-Chat": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
         },
     },
     template="mistral",

@@ -1244,18 +1202,6 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
         },
-        "Qwen2-Math-1.5B": {
-            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
-            DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B",
-        },
-        "Qwen2-Math-7B": {
-            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
-            DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B",
-        },
-        "Qwen2-Math-72B": {
-            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
-            DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B",
-        },
         "Qwen2-0.5B-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",

@@ -1276,18 +1222,6 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
         },
-        "Qwen2-Math-1.5B-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
-            DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B-Instruct",
-        },
-        "Qwen2-Math-7B-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
-            DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B-Instruct",
-        },
-        "Qwen2-Math-72B-Chat": {
-            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
-            DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B-Instruct",
-        },
         "Qwen2-0.5B-int8-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
             DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
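The constants.py hunks only trim registry entries (Gemma-2-2B, Llama 3.1, Mistral Nemo, Qwen2-Math) that post-date v0.8.3. For orientation, a simplified stand-in for what register_model_group does with one of the surviving entries (the real helper in constants.py also tracks templates and other flags in global tables):

from enum import Enum


class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"


SUPPORTED_MODELS = {}


def register_model_group(models: dict, template: str = None) -> None:
    # simplified sketch of the registry update, not the actual implementation
    for name, sources in models.items():
        SUPPORTED_MODELS[name] = {"sources": sources, "template": template}


register_model_group(
    models={"Mistral-7B-v0.3-Chat": {DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3"}},
    template="mistral",
)
print(SUPPORTED_MODELS["Mistral-7B-v0.3-Chat"]["sources"][DownloadSource.DEFAULT])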
src/llamafactory/extras/env.py

@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.8.4.dev0"
+VERSION = "0.8.3"


 def print_env() -> None:
src/llamafactory/extras/misc.py

@@ -37,7 +37,7 @@ from .logging import get_logger
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
-    _is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())
+    _is_bf16_available = is_torch_bf16_gpu_available()
 except Exception:
     _is_bf16_available = False

@@ -79,11 +79,11 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
-        require_version("datasets>=2.16.0,<=2.20.0", "To fix: pip install datasets>=2.16.0,<=2.20.0")
-        require_version("accelerate>=0.30.1,<=0.32.0", "To fix: pip install accelerate>=0.30.1,<=0.32.0")
-        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+        require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
+        require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
+        require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
+        require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
+        require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")


 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:

@@ -137,9 +137,7 @@ def get_device_count() -> int:
     r"""
     Gets the number of available GPU or NPU devices.
     """
-    if is_torch_xpu_available():
-        return torch.xpu.device_count()
-    elif is_torch_npu_available():
+    if is_torch_npu_available():
         return torch.npu.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
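check_dependencies now enforces only lower bounds, matching the docstring change in __init__.py. The helper comes from transformers itself; a short sketch of the restored check, reduced to a single package:

import os

from transformers.utils.versions import require_version

if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
    print("Version checking has been disabled, may lead to unexpected behaviors.")
else:
    # raises ImportError with the hint if the installed version does not satisfy the spec
    require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")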
src/llamafactory/extras/packages.py

@@ -70,11 +70,6 @@ def is_starlette_available():
     return _is_package_available("sse_starlette")


-@lru_cache
-def is_transformers_version_greater_than_4_43():
-    return _get_package_version("transformers") >= version.parse("4.43.0")
-
-
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
src/llamafactory/hparams/data_args.py

@@ -141,6 +141,3 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
-
-        if self.mask_history and self.train_on_prompt:
-            raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
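DataArguments loses the mask_history/train_on_prompt incompatibility check because v0.8.3 has no mask_history field; the surviving streaming check shows the __post_init__ validation pattern. A cut-down sketch (hypothetical mini-dataclass, not the full DataArguments):

from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniDataArguments:  # stand-in for illustration only
    streaming: bool = False
    max_samples: Optional[int] = None

    def __post_init__(self):
        if self.streaming and self.max_samples is not None:
            raise ValueError("`max_samples` is incompatible with `streaming`.")


MiniDataArguments(streaming=True)                    # passes validation
# MiniDataArguments(streaming=True, max_samples=10)  # would raise ValueError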
src/llamafactory/hparams/finetuning_args.py

@@ -326,10 +326,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    use_adam_mini: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use the Adam-mini optimizer."},
-    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
src/llamafactory/hparams/parser.py

@@ -128,9 +128,6 @@ def _check_extra_dependencies(
     if finetuning_args.use_badam:
         require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")

-    if finetuning_args.use_adam_mini:
-        require_version("adam-mini", "To fix: pip install adam-mini")
-
     if finetuning_args.plot_loss:
         require_version("matplotlib", "To fix: pip install matplotlib")

@@ -166,33 +163,32 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    if finetuning_args.stage != "sft":
-        if training_args.predict_with_generate:
-            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
-
-        if data_args.neat_packing:
-            raise ValueError("`neat_packing` cannot be set as True except SFT.")
-
-        if data_args.train_on_prompt or data_args.mask_history:
-            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
+    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
+
+    if finetuning_args.stage != "sft" and data_args.neat_packing:
+        raise ValueError("`neat_packing` cannot be set as True except SFT.")

     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")

     if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
         raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")

-    if finetuning_args.stage == "ppo":
-        if not training_args.do_train:
-            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
-
-        if model_args.shift_attn:
-            raise ValueError("PPO training is incompatible with S^2-Attn.")
-
-        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
-            raise ValueError("Unsloth does not support lora reward model.")
-
-        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
-            raise ValueError("PPO only accepts wandb or tensorboard logger.")
+    if finetuning_args.stage == "ppo" and not training_args.do_train:
+        raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
+
+    if finetuning_args.stage == "ppo" and model_args.shift_attn:
+        raise ValueError("PPO training is incompatible with S^2-Attn.")
+
+    if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
+        raise ValueError("Unsloth does not support lora reward model.")
+
+    if (
+        finetuning_args.stage == "ppo"
+        and training_args.report_to
+        and training_args.report_to[0] not in ["wandb", "tensorboard"]
+    ):
+        raise ValueError("PPO only accepts wandb or tensorboard logger.")

     if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
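In get_train_args the nested stage checks are flattened back into one condition per rule; both formulations reject the same configurations. A throwaway sketch of the equivalence for the PPO/do_train rule, using plain booleans instead of parsed arguments:

stage, do_train = "ppo", False

# reverted style: nested under a single stage check
nested = False
if stage == "ppo":
    if not do_train:
        nested = True

# restored v0.8.3 style: one flat condition per rule
flat = stage == "ppo" and not do_train

print(nested, flat, nested == flat)  # True True True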
src/llamafactory/model/model_utils/attention.py

@@ -36,7 +36,7 @@ def configure_attn_implementation(
     if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
         if is_flash_attn_2_available():
             require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-            require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
+            require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
             logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
             model_args.flash_attn = "fa2"
         else: