OpenDAS / LLaMA-Factory · Commits

Commit 7ea81099, authored Apr 07, 2025 by chenych

    update llama4

Parent: 84987715

Changes: 139 files in this commit. This page shows 20 changed files with 269 additions and 353 deletions (+269 / −353); the remaining changed files are on the following pages.
Changed files shown on this page:

  src/llamafactory/model/model_utils/checkpointing.py   +17  −18
  src/llamafactory/model/model_utils/embedding.py         +1   −3
  src/llamafactory/model/model_utils/kv_cache.py         +44   −0  (new file)
  src/llamafactory/model/model_utils/liger_kernel.py      +7  −34
  src/llamafactory/model/model_utils/longlora.py         +19  −19
  src/llamafactory/model/model_utils/misc.py              +5   −9
  src/llamafactory/model/model_utils/moe.py               +3   −5
  src/llamafactory/model/model_utils/packing.py           +7   −8
  src/llamafactory/model/model_utils/quantization.py      +9  −15
  src/llamafactory/model/model_utils/rope.py              +1   −1
  src/llamafactory/model/model_utils/unsloth.py           +9  −15
  src/llamafactory/model/model_utils/valuehead.py         +3   −4
  src/llamafactory/model/model_utils/visual.py           +48  −61
  src/llamafactory/model/patcher.py                      +19  −35
  src/llamafactory/train/callbacks.py                    +16  −26
  src/llamafactory/train/dpo/trainer.py                  +24  −41
  src/llamafactory/train/dpo/workflow.py                  +3   −3
  src/llamafactory/train/kto/trainer.py                  +23  −37
  src/llamafactory/train/kto/workflow.py                  +3   −3
  src/llamafactory/train/ppo/ppo_utils.py                 +8  −16
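Before the per-file diffs: most of the mechanical churn in this commit replaces the `typing.Dict` / `List` / `Tuple` aliases with the built-in generics available since Python 3.9 (PEP 585), and collapses multi-line docstrings onto one line. A minimal before/after sketch of that annotation style (hypothetical function, not code from the repository):

```python
from typing import Optional

# Old style (typing aliases), removed throughout this commit:
#   from typing import Dict, List, Tuple
#   def split_batch(batch: Dict[str, int]) -> Tuple[List[str], int]: ...

# New style (PEP 585 built-in generics), added throughout this commit:
def split_batch(batch: dict[str, int]) -> tuple[list[str], int]:
    """Return the batch keys and the batch size (hypothetical example)."""
    keys = list(batch.keys())
    return keys, len(keys)


def maybe_first(items: Optional[list[str]] = None) -> Optional[str]:
    # Optional[...] still comes from typing; only the container generics changed.
    return items[0] if items else None
```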
src/llamafactory/model/model_utils/checkpointing.py

-# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py

@@ -21,7 +21,7 @@
 import inspect
 from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch

@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
 def get_unsloth_gradient_checkpointing_func() -> Callable:
     class UnslothGradientCheckpointing(torch.autograd.Function):
-        r"""
-        Saves VRAM by smartly offloading to RAM.
-        """
+        r"""Saves VRAM by smartly offloading to RAM."""

         @staticmethod
         @torch.cuda.amp.custom_fwd

@@ -77,13 +75,14 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
 def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
-    r"""
-    Only applies gradient checkpointing to trainable layers.
-    """
+    r"""Only applies gradient checkpointing to trainable layers."""

     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__

         has_grad = False
         if any(param.requires_grad for param in module.parameters()):

@@ -103,11 +102,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
-    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
     use_unsloth_gc: bool = False,
 ) -> None:
-    r"""
-    Activates gradient checkpointing for the current model.
+    r"""Activates gradient checkpointing for the current model.

     Modification of the original method to enable gradient checkpointing for block-wise optimizer.
     """

@@ -134,17 +132,18 @@ def _gradient_checkpointing_enable(
 def _fp32_forward_post_hook(
-    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+    module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
 ) -> "torch.Tensor":
     return output.to(torch.float32)


 def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Includes:
-        (1) cast the layernorm in fp32
-        (2) make output embedding layer require grads
-        (3) add the upcasting of the lm_head in fp32
+    r"""Prepare the model before training.
+
+    Include:
+        (1) cast the layernorm in fp32
+        (2) make output embedding layer require grads
+        (3) add the upcasting of the lm_head in fp32.
     """
     if model_args.upcast_layernorm:
         logger.info_rank0("Upcasting layernorm weights in float32.")
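For context on the `partial` handling added above: a bound method exposes its module via `__self__`, but wrapping it in `functools.partial` hides that attribute one level deeper, so the checkpointing wrapper must unwrap it before checking whether the layer has trainable parameters. A minimal standalone sketch of that lookup (illustrative names, not the repository's code):

```python
from functools import partial

import torch


def resolve_bound_module(func) -> torch.nn.Module:
    # A bound method exposes its module via __self__; a partial wraps the
    # bound method, so the module sits one level deeper at func.func.__self__.
    if isinstance(func, partial):
        return func.func.__self__
    return func.__self__


layer = torch.nn.Linear(4, 4)
direct = layer.forward            # plain bound method
wrapped = partial(layer.forward)  # partial around the bound method

assert resolve_bound_module(direct) is layer
assert resolve_bound_module(wrapped) is layer
# The wrapper in the diff uses this module to decide whether checkpointing is needed:
print(any(p.requires_grad for p in resolve_bound_module(wrapped).parameters()))
```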
src/llamafactory/model/model_utils/embedding.py

@@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
 def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""
-    Resize token embeddings.
-    """
+    r"""Resize token embeddings."""
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
src/llamafactory/model/model_utils/kv_cache.py  (new file, mode 0 → 100644, +44 lines)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras import logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable:
        setattr(config, "use_cache", model_args.use_cache)
        if hasattr(config, "text_config"):
            setattr(config.text_config, "use_cache", model_args.use_cache)

        if model_args.use_cache:
            logger.info_rank0("KV cache is enabled for faster generation.")
        else:
            logger.info_rank0("KV cache is disabled.")
    else:
        setattr(config, "use_cache", False)
        if hasattr(config, "text_config"):
            setattr(config.text_config, "use_cache", False)

        logger.info_rank0("KV cache is disabled during training.")
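A minimal sketch of how the new helper behaves on a config object, using bare namespaces in place of a real `PretrainedConfig` and `ModelArguments` (stand-ins are illustrative only; the import path is taken from the file added above):

```python
from types import SimpleNamespace

from llamafactory.model.model_utils.kv_cache import configure_kv_cache

# Stand-ins for a real PretrainedConfig / ModelArguments (illustrative only).
config = SimpleNamespace(use_cache=True, text_config=SimpleNamespace(use_cache=True))
model_args = SimpleNamespace(use_cache=True)

# Training: the KV cache is always switched off, including on a nested text_config.
configure_kv_cache(config, model_args, is_trainable=True)
print(config.use_cache, config.text_config.use_cache)  # False False

# Inference: the user's use_cache preference is propagated instead.
configure_kv_cache(config, model_args, is_trainable=False)
print(config.use_cache, config.text_config.use_cache)  # True True
```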
src/llamafactory/model/model_utils/liger_kernel.py

@@ -27,39 +27,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def apply_liger_kernel_to_qwen2_5_vl(
-    rope: bool = True,
-    cross_entropy: bool = False,
-    fused_linear_cross_entropy: bool = True,
-    rms_norm: bool = True,
-    swiglu: bool = True,
-) -> None:
-    from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
-    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
-    from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
-    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
-
-    def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
-        return self.dtype
-
-    modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
-
-    if rope:
-        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
-    if rms_norm:
-        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
-    if cross_entropy:
-        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
-    if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
-    if swiglu:
-        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
-
-
 def apply_liger_kernel(
     config: "PretrainedConfig",
     model_args: "ModelArguments",

@@ -74,6 +41,12 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
     elif model_type == "gemma2":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
+    elif model_type == "gemma3":
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
+    elif model_type == "gemma3_text":
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
     if model_type == "paligemma":
         from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
     elif model_type == "mistral":

@@ -89,7 +62,7 @@ def apply_liger_kernel(
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
-        apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
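The net effect of the liger_kernel.py change is that the hand-rolled qwen2_5_vl patch is dropped in favor of the function now shipped by liger-kernel itself, selected through the same import-as-dispatch pattern as the other model types. A hedged sketch of that dispatch style (the getattr lookup is an illustration of the pattern, not the repository's code; kwargs follow the signature of the function removed above):

```python
def select_liger_patch(model_type: str):
    """Return the upstream liger patch function for a model type, or None (sketch)."""
    try:
        import liger_kernel.transformers as lk  # optional dependency
    except ImportError:
        return None

    # getattr-based lookup mirrors the if/elif "import ... as apply_liger_kernel" chain above.
    return getattr(lk, f"apply_liger_kernel_to_{model_type}", None)


patch = select_liger_patch("qwen2_5_vl")
if patch is not None:
    # Keyword arguments such as rope / rms_norm / swiglu exist on these patch
    # functions (see the removed local implementation above); defaults are used here.
    patch()
```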
src/llamafactory/model/model_utils/longlora.py

-# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
 #
 # This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py

@@ -18,7 +18,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional

 import torch
 import torch.nn as nn

@@ -54,14 +54,14 @@ def llama_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
 (same changes as above: Tuple → tuple in `position_embeddings` and the return type, and the q/k/v projection annotations switch from the string form "torch.Tensor" to torch.Tensor; the context lines "# LlamaFlashAttention2 attention does not support output_attentions" and "output_attentions = False" are unchanged)

@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
     if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward

-        attn_output: "torch.Tensor" = _flash_attention_forward(
+        attn_output: torch.Tensor = _flash_attention_forward(
             query_states,
             key_states,
             value_states,

@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
             is_causal=self.is_causal,
         )
     else:
-        attn_output: "torch.Tensor" = self._flash_attention_forward(
+        attn_output: torch.Tensor = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
         )

@@ -254,9 +254,9 @@ and @@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
 (the same signature and projection-annotation changes once more; the SDPA warning "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention" remains unchanged context)
src/llamafactory/model/model_utils/misc.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING

 from ...extras import logging
 from .visual import COMPOSITE_MODELS

@@ -25,10 +25,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
-    r"""
-    Finds all available modules to apply LoRA, GaLore or APOLLO.
-    """
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
+    r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
     if model_type == "chatglm":

@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     return list(module_names)


-def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
-    r"""
-    Finds the modules in the expanded blocks to apply lora.
-    """
+def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
+    r"""Find the modules in the expanded blocks to apply lora."""
     num_layers = getattr(model.config, "num_hidden_layers", None)
     if not num_layers:
         raise ValueError("Model was not supported.")
src/llamafactory/model/model_utils/moe.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING

 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled

@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
     check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore

@@ -34,9 +34,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
-    r"""
-    Sets module as a leaf module to skip partitioning in deepspeed zero3.
-    """
+    r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
     if not is_deepspeed_zero3_enabled():
         return
src/llamafactory/model/model_utils/packing.py

-# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+# Copyright 2025 Musab Gultekin and the LlamaFactory team.
 #
 # This code is based on the Musab Gultekin's functionary library.
 # https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py

@@ -37,7 +37,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING

 import torch
 import torch.nn.functional as F

@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
-    r"""
-    Gets the sequnce lengths in the current batch.
+    r"""Get the sequnce lengths in the current batch.

     e.g.
     ```python

@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     bsz = attention_mask.size(0)
     dtype, device = attention_mask.dtype, attention_mask.device
     max_num = torch.max(attention_mask).item()
-    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+    counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
     for i in range(max_num):
         counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     return seqlens


-def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
-    r"""
-    Prepares the indices and seqlens for flash attn varlen function.
+def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
+    r"""Prepare the indices and seqlens for flash attn varlen function.

     Returns:
         indices: indices of non-masked tokens from the flattened sequence.

@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
         [0, 2, 5, 6, 8, 11]
         3
     ```
     """
     seqlens_in_batch = get_seqlens_in_batch(attention_mask)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
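The example values in the docstring above can be reproduced directly. A short sketch calling the two helpers on a packed attention mask, where each row labels its packed sequences 1, 2, 3, ... and 0 marks padding (the expected outputs in the comments match the `[0, 2, 5, 6, 8, 11]` / `3` values shown in the diff):

```python
import torch

from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data

# Packed batch: sequence ids 1..k per row, 0 = padding.
attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0],
                               [1, 2, 2, 3, 3, 3]])

print(get_seqlens_in_batch(attention_mask))   # lengths of each packed sequence: 2, 3, 1, 2, 3
indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)
print(cu_seqlens.tolist(), max_seqlen)        # [0, 2, 5, 6, 8, 11] 3
```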
src/llamafactory/model/model_utils/quantization.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and Optimum library.
 # https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py

@@ -19,7 +19,7 @@
 import os
 import random
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any

 import torch
 from datasets import load_dataset

@@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
 @unique
 class QuantizationMethod(str, Enum):
-    r"""
-    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
-    """
+    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""

     BITS_AND_BYTES = "bitsandbytes"
     GPTQ = "gptq"

@@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
     HQQ = "hqq"


-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
-    r"""
-    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
-    """
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
+    r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):
         data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
         data_files = model_args.export_quantization_dataset

@@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
             raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")

         sample_idx = random.randint(0, len(dataset) - 1)
-        sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+        sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
         n_try += 1

         if sample["input_ids"].size(1) > maxlen:
             break  # TODO: fix large maxlen

@@ -101,11 +97,9 @@ def configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    init_kwargs: Dict[str, Any],
+    init_kwargs: dict[str, Any],
 ) -> None:
-    r"""
-    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
-    """
+    r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
             logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

@@ -113,7 +107,7 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

-        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+        quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
src/llamafactory/model/model_utils/rope.py

-# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2025 LMSYS and the LlamaFactory team.
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # This code is inspired by the LMSYS's FastChat library.
src/llamafactory/model/model_utils/unsloth.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.misc import get_current_device

@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
 def _get_unsloth_kwargs(
     config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
         "max_seq_length": model_args.model_max_length or 4096,

@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
 def load_unsloth_pretrained_model(
     config: "PretrainedConfig", model_args: "ModelArguments"
 ) -> Optional["PreTrainedModel"]:
-    r"""
-    Optionally loads pretrained model with unsloth. Used in training.
-    """
-    from unsloth import FastLanguageModel
+    r"""Optionally load pretrained model with unsloth. Used in training."""
+    from unsloth import FastLanguageModel  # type: ignore

     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
     try:

@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
 def get_unsloth_peft_model(
-    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
+    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
 ) -> "PreTrainedModel":
-    r"""
-    Gets the peft model for the pretrained model with unsloth. Used in training.
-    """
-    from unsloth import FastLanguageModel
+    r"""Get the peft model for the pretrained model with unsloth. Used in training."""
+    from unsloth import FastLanguageModel  # type: ignore

     unsloth_peft_kwargs = {
         "model": model,

@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
 def load_unsloth_peft_model(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> "PreTrainedModel":
-    r"""
-    Loads peft model with unsloth. Used in both training and inference.
-    """
-    from unsloth import FastLanguageModel
+    r"""Load peft model with unsloth. Used in both training and inference."""
+    from unsloth import FastLanguageModel  # type: ignore

     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
     try:
src/llamafactory/model/model_utils/valuehead.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

 import torch
 from transformers.utils import cached_file

@@ -30,9 +30,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
-    r"""
-    Loads value head parameters from Hugging Face Hub or local disk.
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
+    r"""Load value head parameters from Hugging Face Hub or local disk.

     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
src/llamafactory/model/model_utils/visual.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py

@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
+from typing import TYPE_CHECKING, Optional

 import torch
 import transformers

@@ -27,7 +27,7 @@ from ...extras import logging
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

     from ...hparams import FinetuningArguments, ModelArguments

@@ -40,9 +40,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 class CompositeModel:
     model_type: str
     projector_key: str
-    vision_model_keys: List[str]
-    language_model_keys: List[str]
-    lora_conflict_keys: List[str]
+    vision_model_keys: list[str]
+    language_model_keys: list[str]
+    lora_conflict_keys: list[str]

     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):

@@ -51,16 +51,26 @@ class CompositeModel:
         return module


-COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
+COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}


 def _register_composite_model(
     model_type: str,
     projector_key: Optional[str] = None,
-    vision_model_keys: Optional[List[str]] = None,
-    language_model_keys: Optional[List[str]] = None,
-    lora_conflict_keys: Optional[List[str]] = None,
+    vision_model_keys: Optional[list[str]] = None,
+    language_model_keys: Optional[list[str]] = None,
+    lora_conflict_keys: Optional[list[str]] = None,
 ):
+    r"""Register a new composite model.
+
+    Args:
+        model_type: model type
+        projector_key: multi_modal_projector
+        vision_model_keys: vision_tower
+        language_model_keys: language_model
+        lora_conflict_keys: None
+    """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
         projector_key=projector_key or "multi_modal_projector",

@@ -116,12 +126,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
 def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Casts projector output to half precision for fine-tuning quantized VLMs.
-    """
+    r"""Cast projector output to half precision for fine-tuning quantized VLMs."""

     def _mm_projector_forward_post_hook(
-        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+        module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)

@@ -137,9 +145,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
 def configure_visual_model(config: "PretrainedConfig") -> None:
-    r"""
-    Patches VLMs before loading them.
-    """
+    r"""Patch VLMs before loading them."""
     if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
         # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

@@ -149,10 +155,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL


-def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
-    r"""
-    Freezes vision tower and language model for VLM full/freeze tuning.
-    """
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
+    r"""Freeze vision tower and language model for VLM full/freeze tuning."""
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
     if model_type in COMPOSITE_MODELS:

@@ -174,47 +178,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     return forbidden_modules


-def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""
-    Computes the number of special tokens per image.
-    """
-    model_type = getattr(config, "model_type", None)
-    if model_type == "llava":
-        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
-        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
-            image_seqlen += 1
-    elif model_type == "paligemma":
-        image_seqlen = config.vision_config.num_image_tokens
-    else:
-        image_seqlen = -1
-
-    return image_seqlen
-
-
-def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Computes the patch size of the vit.
-    """
-    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
-    return patch_size
-
-
-def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Get the vision_feature_select_strategy.
-    """
-    vision_feature_select_strategy = getattr(
-        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
-    )
-    return vision_feature_select_strategy
-
-
 def patch_target_modules(
-    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> List[str]:
-    r"""
-    Freezes vision tower for VLM LoRA tuning.
-    """
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
+) -> list[str]:
+    r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
     if model_type in COMPOSITE_MODELS:
         forbidden_modules = get_forbidden_modules(model.config, finetuning_args)

@@ -231,6 +198,17 @@ def patch_target_modules(
     return target_modules


+_register_composite_model(
+    model_type="gemma3",
+)
+
+
+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
+
+
 _register_composite_model(
     model_type="llava",
 )

@@ -285,6 +263,15 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="qwen2_5_omni_thinker",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
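A small sketch of what the registry touched above provides at lookup time, using the `llama4` entry registered in this commit (the default values in the comments follow the `Args` documentation added to `_register_composite_model`; treat them as assumptions rather than guaranteed values):

```python
from llamafactory.model.model_utils.visual import COMPOSITE_MODELS

# The llama4 entry overrides only vision_model_keys; the projector and
# language-model keys fall back to the registry defaults.
spec = COMPOSITE_MODELS["llama4"]
print(spec.model_type)           # "llama4"
print(spec.projector_key)        # "multi_modal_projector" (default per the docstring above)
print(spec.vision_model_keys)    # ["vision_model"]
print(spec.language_model_keys)  # default language-model key(s), e.g. ["language_model"] (assumed)
```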
src/llamafactory/model/patcher.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any

 import torch
 from peft import PeftModel

@@ -27,19 +27,14 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import (
-    autocast_projector_dtype,
-    configure_visual_model,
-    get_image_seqlen,
-    get_patch_size,
-    get_vision_feature_select_strategy,
-)
+from .model_utils.visual import autocast_projector_dtype, configure_visual_model


 if TYPE_CHECKING:

@@ -56,8 +51,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
+    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length

     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(

@@ -72,28 +67,25 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
 def patch_processor(
     processor: "ProcessorMixin",
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
 ) -> None:
     setattr(processor, "tokenizer", tokenizer)
-    if getattr(config, "vision_config", None) is not None:  # visual models
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "patch_size", get_patch_size(config, processor))
-        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
-        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
-        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
-        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
+    setattr(processor, "image_max_pixels", model_args.image_max_pixels)
+    setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+    setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
+    setattr(processor, "video_max_pixels", model_args.video_max_pixels)
+    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)
+    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)


 def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    init_kwargs: Dict[str, Any],
+    init_kwargs: dict[str, Any],
     is_trainable: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32

@@ -112,19 +104,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
-
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
+    configure_kv_cache(config, model_args, is_trainable)

     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)

-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
-
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

@@ -138,15 +124,13 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

-    # cast data type of the model if:
-    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
-    # 2. quantization_bit is not None (qlora)
-    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
+    # do not cast data type of the model deepspeed zero3 without qlora
+    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
         init_kwargs["torch_dtype"] = model_args.compute_dtype

-        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
+        if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
             if "device_map" not in init_kwargs and model_args.device_map:
-                init_kwargs["device_map"] = model_args.device_map
+                init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True

             if init_kwargs.get("device_map", None) == "auto":
                 init_kwargs["offload_folder"] = model_args.offload_folder
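Note that the rewritten torch_dtype guard in patch_config is not a pure simplification: dropping the FSDP check from the outer condition and moving it to the device_map branch changes behavior in exactly one configuration. A quick sketch with boolean stand-ins for the runtime checks (illustrative only, not repository code) makes the difference explicit:

```python
from itertools import product

# zero3 = DeepSpeed ZeRO-3 enabled, fsdp = FSDP enabled, qlora = quantization_bit is set.
def old_guard(zero3: bool, fsdp: bool, qlora: bool) -> bool:
    # Old condition for setting init_kwargs["torch_dtype"].
    return (not zero3 and not fsdp) or qlora

def new_guard(zero3: bool, fsdp: bool, qlora: bool) -> bool:
    # New condition: only ZeRO-3 without QLoRA keeps the model in float32.
    return not (zero3 and not qlora)

for zero3, fsdp, qlora in product([False, True], repeat=3):
    if old_guard(zero3, fsdp, qlora) != new_guard(zero3, fsdp, qlora):
        print(f"zero3={zero3} fsdp={fsdp} qlora={qlora}: old={old_guard(zero3, fsdp, qlora)} new={new_guard(zero3, fsdp, qlora)}")
# Only the FSDP-without-ZeRO-3, without-QLoRA case differs: the new code now casts the
# dtype there, and the device_map logic is skipped separately when FSDP is enabled.
```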
src/llamafactory/train/callbacks.py

@@ -19,7 +19,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 import transformers

@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
 def fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
 ) -> None:
-    r"""
+    r"""Fix the valuehead checkpoint files.
+
     The model is already unwrapped.

     There are three cases:

@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
     if safe_serialization:
         path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
         with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}

@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
 class FixValueHeadModelCallback(TrainerCallback):
-    r"""
-    A callback for fixing the checkpoint for valuehead models.
-    """
+    r"""A callback for fixing the checkpoint for valuehead models."""

     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):

@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
 class SaveProcessorCallback(TrainerCallback):
-    r"""
-    A callback for saving the processor.
-    """
+    r"""A callback for saving the processor."""

     def __init__(self, processor: "ProcessorMixin") -> None:
         self.processor = processor

@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
 class PissaConvertCallback(TrainerCallback):
-    r"""
-    A callback for converting the PiSSA adapter to a normal one.
-    """
+    r"""A callback for converting the PiSSA adapter to a normal one."""

     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):

@@ -166,20 +161,17 @@ class PissaConvertCallback(TrainerCallback):
             model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
             model.save_pretrained(
-                pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
-            )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
+                pissa_convert_dir,
+                safe_serialization=args.save_safetensors,
+                path_initial_model_for_weight_conversion=pissa_init_dir,
+            )
             model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
             model.set_adapter("default")
-            if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
-                model.delete_adapter("pissa_init")
-
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)


 class LogCallback(TrainerCallback):
-    r"""
-    A callback for logging training and evaluation status.
-    """
+    r"""A callback for logging training and evaluation status."""

     def __init__(self) -> None:
         # Progress

@@ -188,7 +180,7 @@ class LogCallback(TrainerCallback):
         self.max_steps = 0
         self.elapsed_time = ""
         self.remaining_time = ""
-        self.thread_pool: Optional["ThreadPoolExecutor"] = None
+        self.thread_pool: Optional[ThreadPoolExecutor] = None
         # Status
         self.aborted = False
         self.do_train = False

@@ -219,7 +211,7 @@ class LogCallback(TrainerCallback):
         self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
         self.remaining_time = str(timedelta(seconds=int(remaining_time)))

-    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
+    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
         with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
             f.write(json.dumps(logs) + "\n")

@@ -348,9 +340,7 @@
 class ReporterCallback(TrainerCallback):
-    r"""
-    A callback for reporting training status to external logger.
-    """
+    r"""A callback for reporting training status to external logger."""

     def __init__(
         self,
src/llamafactory/train/dpo/trainer.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py

@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import torch
 import torch.nn.functional as F

@@ -128,16 +128,12 @@ class CustomDPOTrainer(DPOTrainer):
         return super()._get_train_sampler()

     @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+    def get_batch_samples(self, *args, **kwargs):
+        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
+        return Trainer.get_batch_samples(self, *args, **kwargs)

     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
-        """
+        r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
         log_odds = (chosen_logps - rejected_logps) - (
             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )

@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
         return orpo_loss

     def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Computes SimPO loss for batched log probabilities of the policy model.
-        """
+        r"""Compute SimPO loss for batched log probabilities of the policy model."""
         pi_logratios = chosen_logps - rejected_logps
         gamma_logratios = self.simpo_gamma / self.beta
         logits = pi_logratios - gamma_logratios

@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
         policy_rejected_logps: "torch.Tensor",
         reference_chosen_logps: Optional["torch.Tensor"],
         reference_rejected_logps: Optional["torch.Tensor"],
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes loss for preference learning.
-        """
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute loss for preference learning."""
         if not self.finetuning_args.use_ref_model:
             if self.loss_type == "orpo":
                 losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)

@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
             batch = nested_detach(batch, clone=True)  # avoid error

-        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
+        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length

@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
+        r"""Compute log probabilities of the reference model."""
         if not self.finetuning_args.use_ref_model:
             return None, None

@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
+        batch: dict[str, "torch.Tensor"],
         train_eval: Literal["train", "eval"] = "train",
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,

@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
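For reference, the `odds_ratio_loss` kept above computes the ORPO log-odds term, log[p_c/(1-p_c)] - log[p_r/(1-p_r)], written with `log1p` for numerical stability. A standalone sketch with toy averaged log-probabilities (the final combination with an SFT term and a beta weight happens elsewhere in the trainer and is only assumed here):

```python
import torch
import torch.nn.functional as F

# Toy averaged log-probabilities for two chosen/rejected pairs (values must be < 0).
chosen_logps = torch.tensor([-0.2, -0.5])
rejected_logps = torch.tensor([-1.0, -1.5])

# log odds ratio, matching the expression in the diff above.
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)

# The OR penalty is -log(sigmoid(log_odds)); larger log_odds (chosen clearly preferred)
# drives the penalty toward zero. Scaling by beta and adding the SFT loss is assumed.
or_term = -F.logsigmoid(log_odds)
print(log_odds, or_term)
```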
src/llamafactory/train/dpo/workflow.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX

@@ -38,7 +38,7 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
src/llamafactory/train/kto/trainer.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py

@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import torch
 from transformers import Trainer

@@ -120,28 +120,22 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
-        r"""
-        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
-        """
+        r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

         return Trainer._get_train_sampler(self)

     @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+    def get_batch_samples(self, *args, **kwargs):
+        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
+        return Trainer.get_batch_samples(self, *args, **kwargs)

     @override
     def forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Runs forward pass and computes the log probabilities.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Run forward pass and computes the log probabilities."""
         batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],

@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
             _, kl_logps, _ = self.forward(model, batch, prefix="kl_")

@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute log probabilities of the reference model."""
         if self.ref_model is None:
             ref_model = model
             ref_context = self.accelerator.unwrap_model(model).disable_adapter()

@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+        batch: dict[str, "torch.Tensor"],
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,

@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         prefix = "eval_" if train_eval == "eval" else ""

@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
         metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
         metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
-        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
         for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
             if f"count/{split}" in metric_dict:
                 for key in ("rewards", "logps", "logits"):
src/llamafactory/train/kto/workflow.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX

@@ -37,7 +37,7 @@ def run_kto(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
src/llamafactory/train/ppo/ppo_utils.py

@@ -14,7 +14,7 @@
 import json
 from contextlib import nullcontext
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+from typing import TYPE_CHECKING, Literal, Optional

 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled

@@ -31,10 +31,8 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead


-def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
-    r"""
-    Gets reward scores from the API server.
-    """
+def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
+    r"""Get reward scores from the API server."""
     headers = {"Content-Type": "application/json"}
     payload = {"model": "model", "messages": messages}
     response = requests.post(server_url, json=payload, headers=headers)

@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> list["torch
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    r"""
-    Replaces the default/reward modules in the model. The model is already unwrapped.
-    """
+    r"""Replace the default/reward modules in the model. The model is already unwrapped."""
     v_head_layer = model.v_head.summary
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore

@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
     v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)


-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
-    r"""
-    Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
-    """
+def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
+    r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
     layer_norm_params = {}
     for name, param in model.named_parameters():
         if param.data.dtype == torch.float32:

@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
     return layer_norm_params


-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-    r"""
-    Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
-    """
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
+    r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
     for name, param in model.named_parameters():
         if name in layernorm_params:
             param.data = layernorm_params[name]