OpenDAS / LLaMA-Factory

Commit 8293100a, authored Jan 16, 2025 by luopl (parent 2778a3d0)

    update to 0.9.2.dev0

Showing 20 of 124 changed files, with 521 additions and 211 deletions (+521, -211).
Files on this page of the diff:

- src/llamafactory/hparams/__init__.py (+6, -1)
- src/llamafactory/hparams/data_args.py (+6, -3)
- src/llamafactory/hparams/finetuning_args.py (+105, -7)
- src/llamafactory/hparams/generating_args.py (+14, -1)
- src/llamafactory/hparams/model_args.py (+14, -1)
- src/llamafactory/hparams/parser.py (+72, -48)
- src/llamafactory/hparams/training_args.py (+48, -0, new file)
- src/llamafactory/model/loader.py (+8, -8)
- src/llamafactory/model/model_utils/attention.py (+3, -3)
- src/llamafactory/model/model_utils/checkpointing.py (+4, -2)
- src/llamafactory/model/model_utils/longlora.py (+7, -10)
- src/llamafactory/model/model_utils/misc.py (+8, -13)
- src/llamafactory/model/model_utils/moe.py (+3, -2)
- src/llamafactory/model/model_utils/packing.py (+9, -43)
- src/llamafactory/model/model_utils/quantization.py (+11, -12)
- src/llamafactory/model/model_utils/unsloth.py (+1, -1)
- src/llamafactory/model/model_utils/visual.py (+113, -31)
- src/llamafactory/model/patcher.py (+12, -2)
- src/llamafactory/train/callbacks.py (+54, -9)
- src/llamafactory/train/dpo/trainer.py (+23, -14)
src/llamafactory/hparams/__init__.py

```diff
@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
-from .parser import get_eval_args, get_infer_args, get_train_args
+from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
+from .training_args import RayArguments, TrainingArguments


 __all__ = [
@@ -26,7 +27,11 @@ __all__ = [
     "FinetuningArguments",
     "GeneratingArguments",
     "ModelArguments",
+    "RayArguments",
+    "TrainingArguments",
     "get_eval_args",
     "get_infer_args",
+    "get_ray_args",
     "get_train_args",
+    "read_args",
 ]
```
src/llamafactory/hparams/data_args.py

```diff
@@ -15,8 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass, field
-from typing import Literal, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, Literal, Optional


 @dataclass
@@ -99,7 +99,7 @@ class DataArguments:
     )
     val_size: float = field(
         default=0.0,
-        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
+        metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
     packing: Optional[bool] = field(
         default=None,
@@ -161,3 +161,6 @@ class DataArguments:

         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
```
src/llamafactory/hparams/finetuning_args.py

```diff
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass, field
-from typing import List, Literal, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, List, Literal, Optional


 @dataclass
@@ -251,6 +251,59 @@ class GaloreArguments:
     )


+@dataclass
+class ApolloArguments:
+    r"""
+    Arguments pertaining to the APOLLO algorithm.
+    """
+
+    use_apollo: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the APOLLO optimizer."},
+    )
+    apollo_target: str = field(
+        default="all",
+        metadata={
+            "help": (
+                "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
+        },
+    )
+    apollo_rank: int = field(
+        default=16,
+        metadata={"help": "The rank of APOLLO gradients."},
+    )
+    apollo_update_interval: int = field(
+        default=200,
+        metadata={"help": "Number of steps to update the APOLLO projection."},
+    )
+    apollo_scale: float = field(
+        default=1.0,
+        metadata={"help": "APOLLO scaling coefficient."},
+    )
+    apollo_proj: Literal["svd", "random"] = field(
+        default="random",
+        metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
+    )
+    apollo_proj_type: Literal["std", "right", "left"] = field(
+        default="std",
+        metadata={"help": "Type of APOLLO projection."},
+    )
+    apollo_scale_type: Literal["channel", "tensor"] = field(
+        default="channel",
+        metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
+    )
+    apollo_layerwise: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+    )
+    apollo_scale_front: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
+    )
+
+
 @dataclass
 class BAdamArgument:
     r"""
@@ -305,7 +358,37 @@ class BAdamArgument:
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
+class SwanLabArguments:
+    use_swanlab: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
+    )
+    swanlab_project: str = field(
+        default="llamafactory",
+        metadata={"help": "The project name in SwanLab."},
+    )
+    swanlab_workspace: str = field(
+        default=None,
+        metadata={"help": "The workspace name in SwanLab."},
+    )
+    swanlab_run_name: str = field(
+        default=None,
+        metadata={"help": "The experiment name in SwanLab."},
+    )
+    swanlab_mode: Literal["cloud", "local"] = field(
+        default="cloud",
+        metadata={"help": "The mode of SwanLab."},
+    )
+    swanlab_api_key: str = field(
+        default=None,
+        metadata={"help": "The API key for SwanLab."},
+    )
+
+
+@dataclass
+class FinetuningArguments(
+    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
+):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
@@ -334,6 +417,10 @@ class FinetuningArguments(...):
         default=True,
         metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
     )
+    freeze_multi_modal_projector: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
+    )
     train_mm_proj_only: bool = field(
         default=False,
         metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
@@ -342,6 +429,10 @@ class FinetuningArguments(...):
         default=False,
         metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
     )
+    disable_shuffling: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable the shuffling of the training set."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
@@ -363,7 +454,9 @@ class FinetuningArguments(...):
         self.lora_target: List[str] = split_arg(self.lora_target)
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
+        self.apollo_target: List[str] = split_arg(self.apollo_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -382,11 +475,11 @@ class FinetuningArguments(...):
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

-        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
-            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
+            raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")

-        if self.use_galore and self.use_badam:
-            raise ValueError("Cannot use GaLore with BAdam together.")
+        if int(self.use_galore) + int(self.use_apollo) + int(self.use_badam) > 1:
+            raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")

         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
@@ -406,3 +499,8 @@ class FinetuningArguments(...):
         if self.pissa_init:
             raise ValueError("`pissa_init` is only valid for LoRA training.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
+        return args
```
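For context, a quick sketch of the new options in use. This is not part of the commit: the import path follows the package layout above, and `finetuning_type="full"` is needed because the new validation forbids combining LoRA with APOLLO.

```python
# Illustrative only: exercising the new APOLLO fields and the api-key-masking
# to_dict() added in this commit.
from llamafactory.hparams import FinetuningArguments

args = FinetuningArguments(
    finetuning_type="full",       # LoRA + APOLLO is rejected by __post_init__
    use_apollo=True,
    apollo_rank=16,               # rank of the low-rank gradient projection
    apollo_update_interval=200,   # refresh the projection every 200 steps
    apollo_proj="random",         # cheaper than "svd"
    swanlab_api_key="secret",
)

dumped = args.to_dict()
# keys ending in "api_key" are masked before the dict is logged anywhere
assert dumped["swanlab_api_key"] == "<SWANLAB_API_KEY>"
```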
src/llamafactory/hparams/generating_args.py

```diff
@@ -15,6 +15,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, Optional

+from transformers import GenerationConfig
+

 @dataclass
 class GeneratingArguments:
@@ -64,11 +66,22 @@ class GeneratingArguments:
         default=None,
         metadata={"help": "Default system message to use in chat completion."},
     )
+    skip_special_tokens: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to remove special tokens in the decoding."},
+    )

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
             args.pop("max_length", None)
         else:
             args.pop("max_new_tokens", None)
+
+        if obey_generation_config:
+            generation_config = GenerationConfig()
+            for key in list(args.keys()):
+                if not hasattr(generation_config, key):
+                    args.pop(key)
+
         return args
```
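A short sketch of what the new `obey_generation_config` switch does in practice (illustrative, not from the commit):

```python
# With obey_generation_config=True, keys that transformers' GenerationConfig
# does not know about (e.g. the new skip_special_tokens, or default_system)
# are dropped, so the dict can be passed straight to model.generate().
from llamafactory.hparams import GeneratingArguments

gen_args = GeneratingArguments(max_new_tokens=512)
full = gen_args.to_dict()                            # keeps every dataclass field
filtered = gen_args.to_dict(obey_generation_config=True)
assert "skip_special_tokens" not in filtered         # unknown to GenerationConfig
assert filtered["max_new_tokens"] == 512             # max_length was popped instead
```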
src/llamafactory/hparams/model_args.py

```diff
@@ -16,7 +16,7 @@
 # limitations under the License.

 import json
-from dataclasses import dataclass, field, fields
+from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union

 import torch
@@ -237,6 +237,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, ...):
         default=False,
         metadata={"help": "Whether or not to disable gradient checkpointing."},
     )
+    use_reentrant_gc: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
+    )
     upcast_layernorm: bool = field(
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
@@ -281,6 +285,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, ...):
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
+    )
     compute_dtype: Optional[torch.dtype] = field(
         default=None,
         init=False,
@@ -336,3 +344,8 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, ...):
             setattr(result, name, value)

         return result
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
+        return args
```
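The same masking idea applies here, this time for tokens. A sketch (the model name is illustrative, and whether `ModelArguments` validates other fields in `__post_init__` is an assumption):

```python
# ModelArguments.to_dict() masks any field whose name ends with "token"
# (hf_hub_token, ms_hub_token, ...) so configs can be reported to
# wandb/SwanLab without leaking credentials.
from llamafactory.hparams import ModelArguments

model_args = ModelArguments(model_name_or_path="meta-llama/Meta-Llama-3-8B", hf_hub_token="hf_xxx")
dumped = model_args.to_dict()
assert dumped["hf_hub_token"] == "<HF_HUB_TOKEN>"
assert dumped["trust_remote_code"] is False  # the new flag defaults to the safe choice
```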
src/llamafactory/hparams/parser.py

```diff
@@ -15,56 +15,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import os
 import sys
-from typing import Any, Dict, Optional, Tuple
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import transformers
-from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+import yaml
+from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
-from transformers.utils.versions import require_version

 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
+from .training_args import RayArguments, TrainingArguments


 logger = logging.get_logger(__name__)

 check_dependencies()


-_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
+_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
 _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


-def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
+def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
     if args is not None:
-        return parser.parse_dict(args)
+        return args

     if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
-        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
-
-    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+        return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        return json.loads(Path(sys.argv[1]).absolute().read_text())
+    else:
+        return sys.argv[1:]
+
+
+def _parse_args(
+    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
+) -> Tuple[Any]:
+    args = read_args(args)
+    if isinstance(args, dict):
+        return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)

-    if unknown_args:
+    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
+
+    if unknown_args and not allow_extra_keys:
         print(parser.format_help())
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
@@ -110,54 +121,64 @@ def _verify_model_args(
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    training_args: Optional["Seq2SeqTrainingArguments"] = None,
+    training_args: Optional["TrainingArguments"] = None,
 ) -> None:
     if model_args.use_unsloth:
-        require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
+        check_version("unsloth", mandatory=True)

     if model_args.enable_liger_kernel:
-        require_version("liger-kernel", "To fix: pip install liger-kernel")
+        check_version("liger-kernel", mandatory=True)

     if model_args.mixture_of_depths is not None:
-        require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
+        check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
+        check_version("vllm>=0.4.3,<=0.6.5")
+        check_version("vllm", mandatory=True)

     if finetuning_args.use_galore:
-        require_version("galore_torch", "To fix: pip install galore_torch")
+        check_version("galore_torch", mandatory=True)
+
+    if finetuning_args.use_apollo:
+        check_version("apollo_torch", mandatory=True)

     if finetuning_args.use_badam:
-        require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
+        check_version("badam>=1.2.1", mandatory=True)

     if finetuning_args.use_adam_mini:
-        require_version("adam-mini", "To fix: pip install adam-mini")
+        check_version("adam-mini", mandatory=True)

     if finetuning_args.plot_loss:
-        require_version("matplotlib", "To fix: pip install matplotlib")
+        check_version("matplotlib", mandatory=True)

     if training_args is not None and training_args.predict_with_generate:
-        require_version("jieba", "To fix: pip install jieba")
-        require_version("nltk", "To fix: pip install nltk")
-        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
+        check_version("jieba", mandatory=True)
+        check_version("nltk", mandatory=True)
+        check_version("rouge_chinese", mandatory=True)


-def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     return _parse_args(parser, args)


-def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     return _parse_args(parser, args)


-def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     return _parse_args(parser, args)


-def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
+def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
+    parser = HfArgumentParser(RayArguments)
+    (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
+    return ray_args
+
+
+def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

     # Setup logging
@@ -237,21 +258,21 @@ def get_train_args(...):
         if is_deepspeed_zero3_enabled():
             raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

-    if (
-        finetuning_args.use_galore
-        and finetuning_args.galore_layerwise
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
-    ):
-        raise ValueError("Distributed training does not support layer-wise GaLore.")
+    if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
+        if finetuning_args.use_galore and finetuning_args.galore_layerwise:
+            raise ValueError("Distributed training does not support layer-wise GaLore.")
+
+        if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
+            raise ValueError("Distributed training does not support layer-wise APOLLO.")

-    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
-        if finetuning_args.badam_mode == "ratio":
-            raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
-        elif not is_deepspeed_zero3_enabled():
-            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
+        if finetuning_args.use_badam:
+            if finetuning_args.badam_mode == "ratio":
+                raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
+            elif not is_deepspeed_zero3_enabled():
+                raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

-    if finetuning_args.use_galore and training_args.deepspeed is not None:
-        raise ValueError("GaLore is incompatible with DeepSpeed yet.")
+    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
+        raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -283,9 +304,13 @@ def get_train_args(...):
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
         logger.warning_rank0("We recommend enable mixed precision training.")

-    if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning_rank0("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+    if (
+        training_args.do_train
+        and (finetuning_args.use_galore or finetuning_args.use_apollo)
+        and not finetuning_args.pure_bf16
+    ):
+        logger.warning_rank0(
+            "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
+        )

     if (not training_args.do_train) and model_args.quantization_bit is not None:
@@ -361,13 +386,12 @@ def get_train_args(...):
             str(model_args.compute_dtype),
         )
     )

     transformers.set_seed(training_args.seed)

     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
+def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     _set_transformers_logging()
@@ -400,7 +424,7 @@ def get_infer_args(...):
     return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
+def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     _set_transformers_logging()
```
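To make the new entry point concrete, a minimal sketch of the input forms `read_args` now accepts (assuming `llamafactory` is importable; values are illustrative):

```python
from llamafactory.hparams import read_args

# explicit dicts or CLI-style string lists are returned unchanged,
# so callers can drive training programmatically
assert read_args({"stage": "sft"}) == {"stage": "sft"}
assert read_args(["--stage", "sft"]) == ["--stage", "sft"]

# with no argument, sys.argv is consulted instead: a lone *.yaml/*.yml or
# *.json path is loaded via yaml.safe_load / json.loads (rather than the
# old HfArgumentParser file helpers), anything else is returned as the raw
# argument list.
```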
src/llamafactory/hparams/training_args.py (new file, mode 100644)

```python
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict

from ..extras.misc import use_ray


@dataclass
class RayArguments:
    r"""
    Arguments pertaining to the Ray training.
    """

    ray_run_name: Optional[str] = field(
        default=None,
        metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
    )
    ray_num_workers: int = field(
        default=1,
        metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
    )
    resources_per_worker: Union[dict, str] = field(
        default_factory=lambda: {"GPU": 1},
        metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
    )
    placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
        default="PACK",
        metadata={"help": "The placement strategy for Ray training. Default is PACK."},
    )

    def __post_init__(self):
        self.use_ray = use_ray()
        if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
            self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))


@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
    r"""
    Arguments pertaining to the trainer.
    """

    def __post_init__(self):
        Seq2SeqTrainingArguments.__post_init__(self)
        RayArguments.__post_init__(self)
```
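A sketch of the string-to-dict conversion in action (not from the commit; whether `use_ray()` reads an environment flag is an assumption based on its name):

```python
# resources_per_worker can be given as a JSON string, which is convenient on
# the command line; __post_init__ converts it via json.loads plus transformers'
# _convert_str_dict helper.
from llamafactory.hparams.training_args import RayArguments

ray_args = RayArguments(resources_per_worker='{"GPU": 2}')
assert ray_args.resources_per_worker == {"GPU": 2}
assert isinstance(ray_args.use_ray, bool)  # presumably derived from an env flag
```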
src/llamafactory/model/loader.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

 import torch
@@ -52,7 +53,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     skip_check_imports()
     model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
     return {
-        "trust_remote_code": True,
+        "trust_remote_code": model_args.trust_remote_code,
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "token": model_args.hf_hub_token,
@@ -85,6 +86,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e

+    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length
+
     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=model_args.new_special_tokens),
@@ -155,7 +159,7 @@ def load_model(
             load_class = AutoModelForCausalLM

         if model_args.train_from_scratch:
-            model = load_class.from_config(config, trust_remote_code=True)
+            model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
         else:
             model = load_class.from_pretrained(**init_kwargs)
@@ -202,12 +206,8 @@ def load_model(
         logger.info_rank0(param_stats)

-    if model_args.print_param_status:
+    if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
         for name, param in model.named_parameters():
-            print(
-                "name: {}, dtype: {}, device: {}, trainable: {}".format(
-                    name, param.dtype, param.device, param.requires_grad
-                )
-            )
+            print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")

     return model
```
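The new `LOCAL_RANK` guard is a common pattern for avoiding duplicated output under torchrun, shown standalone below (a sketch, not part of the commit):

```python
# torchrun sets LOCAL_RANK per process; rank 0 on each node prints, the
# others stay silent, so parameter status is reported once instead of
# once per GPU.
import os

def is_local_process_zero() -> bool:
    return int(os.getenv("LOCAL_RANK", "0")) == 0

if is_local_process_zero():
    print("parameter status would be printed exactly once per node here")
```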
src/llamafactory/model/model_utils/attention.py

```diff
@@ -15,9 +15,9 @@
 from typing import TYPE_CHECKING

 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-from transformers.utils.versions import require_version

 from ...extras import logging
+from ...extras.misc import check_version


 if TYPE_CHECKING:
@@ -35,8 +35,8 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
-                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
+                check_version("transformers>=4.42.4")
+                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != "fa2":
                     logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = "fa2"
```
src/llamafactory/model/model_utils/checkpointing.py

```diff
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
@@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
             _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
         )
         model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
-        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
+        model.gradient_checkpointing_enable(
+            gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
+        )
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info_rank0("Gradient checkpointing enabled.")
```
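What the new `use_reentrant_gc` flag toggles, in plain PyTorch terms. This is a minimal standalone sketch; LLaMA-Factory passes the flag through `gradient_checkpointing_kwargs` rather than calling `checkpoint()` directly:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# reentrant: the legacy autograd path, stricter about inputs requiring grad
y_reentrant = checkpoint(layer, x, use_reentrant=True)
# non-reentrant: the newer path, more flexible (e.g. works with BAdam-style tricks)
y_nonreentrant = checkpoint(layer, x, use_reentrant=False)

(y_reentrant.sum() + y_nonreentrant.sum()).backward()
```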
src/llamafactory/model/model_utils/longlora.py

```diff
@@ -23,21 +23,18 @@ from typing import TYPE_CHECKING, Optional, Tuple
 import torch
 import torch.nn as nn
 import transformers
-from transformers.models.llama.modeling_llama import (
-    Cache,
-    LlamaAttention,
-    LlamaFlashAttention2,
-    LlamaSdpaAttention,
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
-from transformers.utils.versions import require_version
+from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than


+if not is_transformers_version_greater_than("4.48.0"):
+    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
+
+
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -353,7 +350,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
```
src/llamafactory/model/model_utils/misc.py

```diff
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, List

 from ...extras import logging
+from .visual import COMPOSITE_MODELS


 if TYPE_CHECKING:
@@ -26,7 +27,7 @@ logger = logging.get_logger(__name__)
 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
     r"""
-    Finds all available modules to apply lora or galore.
+    Finds all available modules to apply LoRA, GaLore or APOLLO.
     """
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
@@ -34,18 +35,12 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool):
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
-        forbidden_modules.add("multi_modal_projector")
-    elif model_type == "qwen2_vl":
-        forbidden_modules.add("merger")

-    if freeze_vision_tower:
-        if model_type == "mllama":
-            forbidden_modules.add("vision_model")
-        elif model_type == "qwen2_vl":
-            forbidden_modules.add("visual")
-        else:
-            forbidden_modules.add("vision_tower")
+    if model_type in COMPOSITE_MODELS:
+        forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
+
+    if freeze_vision_tower and model_type in COMPOSITE_MODELS:
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)

     module_names = set()
     for name, module in model.named_modules():
```
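A sketch of the registry-driven lookup that replaces the hard-coded per-model branches (it assumes the registrations added in visual.py further down):

```python
from llamafactory.model.model_utils.visual import COMPOSITE_MODELS

spec = COMPOSITE_MODELS["qwen2_vl"]
forbidden = {"lm_head", spec.projector_key}  # the projector is always excluded
forbidden.update(spec.vision_model_keys)     # plus the ViT when it is frozen
# forbidden == {"lm_head", "visual.merger", "visual.patch_embed", "visual.blocks"}
```

The design win is that adding a new VLM no longer requires touching this function at all; only the registry entry changes.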
src/llamafactory/model/model_utils/moe.py

```diff
@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
+
+from ...extras.misc import check_version


 if TYPE_CHECKING:
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
 def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore

     set_z3_leaf_modules(model, leaf_modules)
```
src/llamafactory/model/model_utils/packing.py

```diff
@@ -41,16 +41,17 @@ from typing import TYPE_CHECKING, Tuple
 import torch
 import torch.nn.functional as F
-from transformers.utils.versions import require_version

 from ...extras import logging
-from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than


-if TYPE_CHECKING:
-    from transformers import PretrainedConfig
+if is_transformers_version_greater_than("4.43.0"):
+    import transformers.modeling_flash_attention_utils


 if TYPE_CHECKING:
     from ...hparams import ModelArguments
@@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
     return indices, cu_seqlens, max_seqlen_in_batch


-def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than("4.43.0"):
-        import transformers.modeling_flash_attention_utils
-
-        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
-        return
-
-    import transformers.models
-
-    if model_type == "cohere":
-        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
-    elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
-    elif model_type == "mistral":
-        transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
-    elif model_type == "phi3":
-        transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    model_type = getattr(config, "model_type", None)
-    if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
-        _patch_for_block_diag_attn(model_type)
-        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
-    else:
-        raise ValueError("Current model does not support block diagonal attention.")
+    check_version("transformers>=4.43.0,<=4.46.1")
+    transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
+    logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
```
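For intuition about why `get_unpad_data` is patched in at all: with sequence packing, flash attention needs cumulative boundaries per packed sub-sequence, not per padded row. A toy illustration (my own, not the repo's implementation):

```python
import torch

# one packed row holding two samples: ids 1,1 (seq 1) and 2,2,2 (seq 2)
attention_mask = torch.tensor([[1, 1, 2, 2, 2]])
seqlens = torch.bincount(attention_mask[attention_mask != 0])[1:]  # tensor([2, 3])
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0), (1, 0))
print(cu_seqlens)  # tensor([0, 2, 5]): block-diagonal attention boundaries
```

Feeding these boundaries to flash attention keeps sample 2 from attending to sample 1, which is exactly the "without cross-attention" guarantee in the log message.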
src/llamafactory/model/model_utils/quantization.py

```diff
@@ -26,11 +26,10 @@ from datasets import load_dataset
 from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version

 from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device


 if TYPE_CHECKING:
@@ -118,15 +117,15 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            check_version("auto_gptq>=0.5.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama

         if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
+            check_version("autoawq", mandatory=True)

         if quant_method == QuantizationMethod.AQLM:
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
@@ -136,8 +135,8 @@ def configure_quantization(
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

-        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        check_version("optimum>=1.17.0", mandatory=True)
+        check_version("auto_gptq>=0.5.0", mandatory=True)
         from accelerate.utils import get_max_memory

         if getattr(config, "model_type", None) == "chatglm":
@@ -154,10 +153,10 @@ def configure_quantization(
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
             if model_args.quantization_bit == 8:
-                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             elif model_args.quantization_bit == 4:
-                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                check_version("bitsandbytes>=0.39.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=model_args.compute_dtype,
@@ -175,7 +174,7 @@ def configure_quantization(
                 if model_args.quantization_bit != 4:
                     raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

-                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+                check_version("bitsandbytes>=0.43.0", mandatory=True)
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
@@ -187,7 +186,7 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

-            require_version("hqq", "To fix: pip install hqq")
+            check_version("hqq", mandatory=True)
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

-            require_version("eetq", "To fix: pip install eetq")
+            check_version("eetq", mandatory=True)
             init_kwargs["quantization_config"] = EetqConfig()
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
```
src/llamafactory/model/model_utils/unsloth.py

```diff
@@ -39,7 +39,7 @@ def _get_unsloth_kwargs(
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
-        "trust_remote_code": True,
+        "trust_remote_code": model_args.trust_remote_code,
         "use_gradient_checkpointing": "unsloth",
     }
```
src/llamafactory/model/model_utils/visual.py

```diff
@@ -15,7 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union

 import torch
 import transformers
@@ -35,6 +36,40 @@ logger = logging.get_logger(__name__)
 transformers_logger = transformers.utils.logging.get_logger(__name__)


+@dataclass
+class CompositeModel:
+    model_type: str
+    projector_key: str
+    vision_model_keys: List[str]
+    language_model_keys: List[str]
+
+    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
+        for key in self.projector_key.split("."):
+            module = getattr(module, key)
+
+        return module
+
+
+COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
+
+
+def _register_composite_model(
+    model_type: str,
+    projector_key: Optional[str] = None,
+    vision_model_keys: Optional[List[str]] = None,
+    language_model_keys: Optional[List[str]] = None,
+):
+    projector_key = projector_key or "multi_modal_projector"
+    vision_model_keys = vision_model_keys or ["vision_tower"]
+    language_model_keys = language_model_keys or ["language_model"]
+    COMPOSITE_MODELS[model_type] = CompositeModel(
+        model_type=model_type,
+        projector_key=projector_key,
+        vision_model_keys=vision_model_keys,
+        language_model_keys=language_model_keys,
+    )
+
+
 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
     def __init__(self, config: "LlavaConfig") -> None:
         super().__init__()
@@ -92,10 +127,8 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
-            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
-        elif model_type == "qwen2_vl":
-            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+        if model_type in COMPOSITE_MODELS:
+            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
         else:
             return
@@ -107,15 +140,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     r"""
     Patches VLMs before loading them.
     """
-    model_type = getattr(config, "model_type", None)
-    if model_type in [
-        "llava",
-        "llava_next",
-        "llava_next_video",
-        "paligemma",
-        "pixtral",
-        "video_llava",
-    ]:  # required for ds zero3 and valuehead models
+    if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
+        # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

     if getattr(config, "is_yi_vl_derived_model", None):
@@ -129,19 +155,21 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
+    if model_type in COMPOSITE_MODELS:
         if finetuning_args.freeze_vision_tower:
-            forbidden_modules.add("vision_tower")
+            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
+            logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
+            forbidden_modules.update(vision_model_keys)

-        if finetuning_args.train_mm_proj_only:
-            forbidden_modules.add("language_model")
+        if finetuning_args.freeze_multi_modal_projector:
+            projector_key = COMPOSITE_MODELS[model_type].projector_key
+            logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
+            forbidden_modules.add(projector_key)

-    elif model_type == "qwen2_vl":
-        if finetuning_args.freeze_vision_tower:
-            forbidden_modules.add("visual")
-
-        if finetuning_args.train_mm_proj_only:
-            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
+        if finetuning_args.train_mm_proj_only:
+            language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
+            logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
+            forbidden_modules.update(language_model_keys)

     return forbidden_modules
@@ -188,19 +216,73 @@ def patch_target_modules(
     Freezes vision tower for VLM LoRA tuning.
     """
     model_type = getattr(config, "model_type", None)
+    vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None)
     if finetuning_args.freeze_vision_tower:
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
-            return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
-        elif model_type == "mllama":
-            return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
-        elif model_type == "qwen2_vl":
-            return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+        if model_type in COMPOSITE_MODELS:
+            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
+            logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
+            vision_model_keys = "|".join(vision_model_keys)
+            target_modules = "|".join(target_modules)
+            return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*"
         else:
             return target_modules
     else:
         if model_type == "qwen2_vl":  # avoid attaching lora to Conv3D layer
             return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
-        elif model_type == "pixtral":
+        elif vit_model_type == "pixtral":
             return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
+
+
+_register_composite_model(
+    model_type="llava",
+)
+
+
+_register_composite_model(
+    model_type="llava_next",
+)
+
+
+_register_composite_model(
+    model_type="llava_next_video",
+)
+
+
+_register_composite_model(
+    model_type="minicpmv",
+    vision_model_keys=["vpm"],
+    language_model_keys=["llm"],
+)
+
+
+_register_composite_model(
+    model_type="minicpmo",
+    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    language_model_keys=["llm"],
+)
+
+
+_register_composite_model(
+    model_type="paligemma",
+)
+
+
+_register_composite_model(
+    model_type="video_llava",
+)
+
+
+_register_composite_model(
+    model_type="mllama",
+    vision_model_keys=["vision_model"],
+)
+
+
+_register_composite_model(
+    model_type="qwen2_vl",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["model", "lm_head"],
+)
```
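With the registry, supporting another VLM reduces to one call. A sketch with a hypothetical model type and module keys (none of these names come from the commit):

```python
from llamafactory.model.model_utils.visual import (
    COMPOSITE_MODELS,
    _register_composite_model,
)

_register_composite_model(
    model_type="my_vlm",                # hypothetical model type
    projector_key="connector",          # dotted paths also work, cf. "visual.merger"
    vision_model_keys=["vision_encoder"],
    language_model_keys=["language_model"],
)
assert COMPOSITE_MODELS["my_vlm"].projector_key == "connector"
```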
src/llamafactory/model/patcher.py

```diff
@@ -24,6 +24,7 @@ from transformers.modeling_utils import is_fsdp_enabled
 from ..extras import logging
 from ..extras.misc import infer_optim_dtype
+from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
@@ -96,7 +97,7 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
-    configure_packing(config, model_args, is_trainable)
+    configure_packing(model_args, is_trainable)

     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
@@ -110,9 +111,16 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")

+    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
+        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@@ -145,7 +153,9 @@ def patch_model(
     ):
         gen_config.do_sample = True

-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)

     if add_valuehead:
```
src/llamafactory/train/callbacks.py

```diff
@@ -35,17 +35,20 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory
+from ..extras.misc import get_peak_memory, use_ray


 if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import save_file


 if TYPE_CHECKING:
     from transformers import TrainerControl, TrainerState, TrainingArguments
     from trl import AutoModelForCausalLMWithValueHead

+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+

 logger = logging.get_logger(__name__)
@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):
     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a checkpoint save.
-        """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):
     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -197,7 +194,7 @@ class LogCallback(TrainerCallback):
         self.do_train = False
         # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
-        if self.webui_mode:
+        if self.webui_mode and not use_ray():
             signal.signal(signal.SIGABRT, self._set_abort)
             self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
             logging.add_handler(self.logger_handler)
@@ -242,7 +239,7 @@ class LogCallback(TrainerCallback):
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning_once("Previous trainer log in this folder will be deleted.")
+            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

     @override
@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
             remaining_time=self.remaining_time,
         )
         self.thread_pool.submit(self._write_log, args.output_dir, logs)
+
+
+class ReporterCallback(TrainerCallback):
+    r"""
+    A callback for reporting training status to external logger.
+    """
+
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.model_args = model_args
+        self.data_args = data_args
+        self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
+        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
+
+    @override
+    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if not state.is_world_process_zero:
+            return
+
+        if "wandb" in args.report_to:
+            import wandb
+
+            wandb.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
+
+        if self.finetuning_args.use_swanlab:
+            import swanlab  # type: ignore
+
+            swanlab.config.update(
+                {
+                    "model_args": self.model_args.to_dict(),
+                    "data_args": self.data_args.to_dict(),
+                    "finetuning_args": self.finetuning_args.to_dict(),
+                    "generating_args": self.generating_args.to_dict(),
+                }
+            )
```
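How the new callback is presumably wired up (a sketch; the commit does not show the training entry point, so the trainer attachment here is an assumption):

```python
from llamafactory.hparams import get_train_args
from llamafactory.train.callbacks import ReporterCallback

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args()
reporter = ReporterCallback(model_args, data_args, finetuning_args, generating_args)
# trainer.add_callback(reporter)
# on_train_begin then pushes the to_dict() of every hparam group (with api
# keys and hub tokens masked, see the hparams changes above) to wandb/SwanLab.
```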
src/llamafactory/train/dpo/trainer.py

```diff
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -29,9 +29,9 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


 if TYPE_CHECKING:
@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
@@ -79,6 +82,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.simpo_gamma = finetuning_args.simpo_gamma

         Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
@@ -97,9 +101,6 @@ class CustomDPOTrainer(DPOTrainer):
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))

-        if finetuning_args.pissa_convert:
-            self.callback_handler.add_callback(PissaConvertCallback)
-
         if finetuning_args.use_badam:
             from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -119,6 +120,13 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
     @override
     def get_batch_samples(self, epoch_iterator, num_batches):
         r"""
@@ -185,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
-            batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+            batch = nested_detach(batch, clone=True)  # avoid error

         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
@@ -266,17 +274,18 @@ class CustomDPOTrainer(DPOTrainer):
         return losses.mean(), metrics

     @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss
```
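The new `disable_shuffling` switch boils down to swapping in a `SequentialSampler`. A tiny standalone sketch of its behavior (not from the commit):

```python
import torch

# a SequentialSampler yields indices in fixed dataset order every epoch,
# which is what the DPO trainer now returns when disable_shuffling is set
dataset = list(range(5))
sampler = torch.utils.data.SequentialSampler(dataset)
assert list(sampler) == [0, 1, 2, 3, 4]  # deterministic, no shuffling
```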