OpenDAS / LLaMA-Factory · Commits · 7ea81099

Commit 7ea81099, authored Apr 07, 2025 by chenych
Commit message: update llama4
Parent: 84987715

Changes: 139 · Showing 20 changed files with 269 additions and 353 deletions (+269, -353)
src/llamafactory/model/model_utils/checkpointing.py   +17  -18
src/llamafactory/model/model_utils/embedding.py        +1   -3
src/llamafactory/model/model_utils/kv_cache.py        +44   -0
src/llamafactory/model/model_utils/liger_kernel.py     +7  -34
src/llamafactory/model/model_utils/longlora.py        +19  -19
src/llamafactory/model/model_utils/misc.py             +5   -9
src/llamafactory/model/model_utils/moe.py              +3   -5
src/llamafactory/model/model_utils/packing.py          +7   -8
src/llamafactory/model/model_utils/quantization.py     +9  -15
src/llamafactory/model/model_utils/rope.py             +1   -1
src/llamafactory/model/model_utils/unsloth.py          +9  -15
src/llamafactory/model/model_utils/valuehead.py        +3   -4
src/llamafactory/model/model_utils/visual.py          +48  -61
src/llamafactory/model/patcher.py                     +19  -35
src/llamafactory/train/callbacks.py                   +16  -26
src/llamafactory/train/dpo/trainer.py                 +24  -41
src/llamafactory/train/dpo/workflow.py                 +3   -3
src/llamafactory/train/kto/trainer.py                 +23  -37
src/llamafactory/train/kto/workflow.py                 +3   -3
src/llamafactory/train/ppo/ppo_utils.py                +8  -16
src/llamafactory/model/model_utils/checkpointing.py

-# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
...
@@ -21,7 +21,7 @@
 import inspect
 from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch
...
@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
 def get_unsloth_gradient_checkpointing_func() -> Callable:
     class UnslothGradientCheckpointing(torch.autograd.Function):
-        r"""
-        Saves VRAM by smartly offloading to RAM.
-        """
+        r"""Saves VRAM by smartly offloading to RAM."""

         @staticmethod
         @torch.cuda.amp.custom_fwd
...
@@ -77,13 +75,14 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
 def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
-    r"""
-    Only applies gradient checkpointing to trainable layers.
-    """
+    r"""Only applies gradient checkpointing to trainable layers."""

     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__

         has_grad = False
         if any(param.requires_grad for param in module.parameters()):
...
@@ -103,11 +102,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
-    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
     use_unsloth_gc: bool = False,
 ) -> None:
-    r"""
-    Activates gradient checkpointing for the current model.
+    r"""Activates gradient checkpointing for the current model.

     Modification of the original method to enable gradient checkpointing for block-wise optimizer.
     """
...
@@ -134,17 +132,18 @@ def _gradient_checkpointing_enable(
 def _fp32_forward_post_hook(
-    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+    module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
 ) -> "torch.Tensor":
     return output.to(torch.float32)


 def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Includes:
+    r"""Prepare the model before training.
+
+    Include:
     (1) cast the layernorm in fp32
     (2) make output embedding layer require grads
-    (3) add the upcasting of the lm_head in fp32
+    (3) add the upcasting of the lm_head in fp32.
     """
     if model_args.upcast_layernorm:
         logger.info_rank0("Upcasting layernorm weights in float32.")
...
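Aside on the change above: the custom gradient-checkpointing wrapper now unwraps functools.partial before reading __self__, because a bound forward method wrapped in partial no longer exposes the owning module directly. A minimal, self-contained sketch of that lookup (plain PyTorch, not LLaMA-Factory code):

from functools import partial

import torch

layer = torch.nn.Linear(4, 4)
bound = layer.forward             # bound method: has __self__ -> layer
wrapped = partial(layer.forward)  # partial object: no __self__ attribute

def owning_module(func) -> torch.nn.Module:
    # Mirrors the branch added in the diff: unwrap partial first.
    if isinstance(func, partial):
        return func.func.__self__
    return func.__self__

assert owning_module(bound) is layer
assert owning_module(wrapped) is layer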
src/llamafactory/model/model_utils/embedding.py

...
@@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
 def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""
-    Resize token embeddings.
-    """
+    r"""Resize token embeddings."""
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
...
src/llamafactory/model/model_utils/kv_cache.py (new file, +44 -0)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras import logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable:
        setattr(config, "use_cache", model_args.use_cache)
        if hasattr(config, "text_config"):
            setattr(config.text_config, "use_cache", model_args.use_cache)

        if model_args.use_cache:
            logger.info_rank0("KV cache is enabled for faster generation.")
        else:
            logger.info_rank0("KV cache is disabled.")
    else:
        setattr(config, "use_cache", False)
        if hasattr(config, "text_config"):
            setattr(config.text_config, "use_cache", False)

        logger.info_rank0("KV cache is disabled during training.")
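A small usage sketch of the new helper above, assuming the package is importable as llamafactory; the SimpleNamespace stand-ins replace real PretrainedConfig/ModelArguments objects only for illustration:

from types import SimpleNamespace

from llamafactory.model.model_utils.kv_cache import configure_kv_cache

config = SimpleNamespace(use_cache=True, text_config=SimpleNamespace(use_cache=True))
model_args = SimpleNamespace(use_cache=True)

# During training the KV cache is always disabled, on both the top-level
# config and the nested text_config used by composite (multimodal) models.
configure_kv_cache(config, model_args, is_trainable=True)
print(config.use_cache, config.text_config.use_cache)  # False False

# At inference time it follows model_args.use_cache instead.
configure_kv_cache(config, model_args, is_trainable=False)
print(config.use_cache, config.text_config.use_cache)  # True True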
src/llamafactory/model/model_utils/liger_kernel.py

...
@@ -27,39 +27,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def apply_liger_kernel_to_qwen2_5_vl(
-    rope: bool = True,
-    cross_entropy: bool = False,
-    fused_linear_cross_entropy: bool = True,
-    rms_norm: bool = True,
-    swiglu: bool = True,
-) -> None:
-    from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
-    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
-    from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
-    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
-
-    def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
-        return self.dtype
-
-    modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
-
-    if rope:
-        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
-
-    if rms_norm:
-        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
-
-    if cross_entropy:
-        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
-
-    if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
-
-    if swiglu:
-        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
-
-
 def apply_liger_kernel(
     config: "PretrainedConfig",
     model_args: "ModelArguments",
...
@@ -74,6 +41,12 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
     elif model_type == "gemma2":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
+    elif model_type == "gemma3":
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
+    elif model_type == "gemma3_text":
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
+    if model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
     elif model_type == "mistral":
...
@@ -89,7 +62,7 @@ def apply_liger_kernel(
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
-        apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
...
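The hunks above drop the local Qwen2.5-VL patch in favor of the entry point shipped with liger-kernel, so every supported architecture is handled by the same import-and-rebind dispatch. A hedged sketch of calling the selected function (the keyword arguments mirror the removed local implementation; whether the upstream function accepts exactly these names is an assumption):

# Sketch only: requires the optional liger-kernel dependency.
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel

# The removed local copy exposed these toggles; the upstream entry point is
# assumed here to take the same ones.
apply_liger_kernel(
    rope=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
    rms_norm=True,
    swiglu=True,
)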
src/llamafactory/model/model_utils/longlora.py

-# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
 #
 # This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
...
@@ -18,7 +18,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional

 import torch
 import torch.nn as nn
...
@@ -54,14 +54,14 @@ def llama_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
...
@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False

     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
...
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
     if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward

-        attn_output: "torch.Tensor" = _flash_attention_forward(
+        attn_output: torch.Tensor = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
...
@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
             is_causal=self.is_causal,
         )
     else:
-        attn_output: "torch.Tensor" = self._flash_attention_forward(
+        attn_output: torch.Tensor = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
         )
...
@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
...
@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
...
src/llamafactory/model/model_utils/misc.py

...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING

 from ...extras import logging
 from .visual import COMPOSITE_MODELS
...
@@ -25,10 +25,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
-    r"""
-    Finds all available modules to apply LoRA, GaLore or APOLLO.
-    """
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
+    r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
     if model_type == "chatglm":
...
@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     return list(module_names)


-def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
-    r"""
-    Finds the modules in the expanded blocks to apply lora.
-    """
+def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
+    r"""Find the modules in the expanded blocks to apply lora."""
     num_layers = getattr(model.config, "num_hidden_layers", None)
     if not num_layers:
         raise ValueError("Model was not supported.")
...
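Most hunks in this commit make the same mechanical change as the one above: typing.List, Dict, Tuple, Sequence, and Set annotations are replaced by the built-in generics available since Python 3.9 (PEP 585). A tiny before/after sketch, unrelated to any specific LLaMA-Factory function:

# Before: typing aliases had to be imported.
from typing import Dict, List

def old_style(names: List[str]) -> Dict[str, int]:
    return {name: len(name) for name in names}

# After (Python >= 3.9): built-in types are subscriptable directly.
def new_style(names: list[str]) -> dict[str, int]:
    return {name: len(name) for name in names}

assert old_style(["lora"]) == new_style(["lora"]) == {"lora": 4}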
src/llamafactory/model/model_utils/moe.py

...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING

 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
...
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
     check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore
...
@@ -34,9 +34,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
-    r"""
-    Sets module as a leaf module to skip partitioning in deepspeed zero3.
-    """
+    r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
     if not is_deepspeed_zero3_enabled():
         return
...
src/llamafactory/model/model_utils/packing.py

-# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+# Copyright 2025 Musab Gultekin and the LlamaFactory team.
 #
 # This code is based on the Musab Gultekin's functionary library.
 # https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
...
@@ -37,7 +37,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING

 import torch
 import torch.nn.functional as F
...
@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
-    r"""
-    Gets the sequnce lengths in the current batch.
+    r"""Get the sequnce lengths in the current batch.

     e.g.
     ```python
...
@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     bsz = attention_mask.size(0)
     dtype, device = attention_mask.dtype, attention_mask.device
     max_num = torch.max(attention_mask).item()
-    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+    counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
     for i in range(max_num):
         counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
...
@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     return seqlens


-def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
-    r"""
-    Prepares the indices and seqlens for flash attn varlen function.
+def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
+    r"""Prepare the indices and seqlens for flash attn varlen function.

     Returns:
         indices: indices of non-masked tokens from the flattened sequence.
...
@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     [0, 2, 5, 6, 8, 11]
     3
     ```
+
     """
     seqlens_in_batch = get_seqlens_in_batch(attention_mask)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
...
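Only part of get_seqlens_in_batch is visible in the hunk above; the following self-contained sketch reproduces the counting loop it shows and matches the docstring's expected output (the final flatten-and-filter step is an assumption, since it lies outside the visible lines):

import torch

def seqlens_in_batch_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    # In packed batches the attention mask labels sequences 1, 2, 3, ... and padding 0.
    bsz = attention_mask.size(0)
    max_num = torch.max(attention_mask).item()
    counts = torch.zeros((bsz, max_num), dtype=attention_mask.dtype)
    for i in range(max_num):
        counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)

    counts = counts.flatten()
    return counts[counts.nonzero().squeeze(dim=-1)]  # drop zero-length slots

mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
print(seqlens_in_batch_sketch(mask))  # tensor([2, 3, 1, 2, 3])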
src/llamafactory/model/model_utils/quantization.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and Optimum library.
 # https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
...
@@ -19,7 +19,7 @@
 import os
 import random
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any

 import torch
 from datasets import load_dataset
...
@@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
 @unique
 class QuantizationMethod(str, Enum):
-    r"""
-    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
-    """
+    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""

     BITS_AND_BYTES = "bitsandbytes"
     GPTQ = "gptq"
...
@@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
     HQQ = "hqq"


-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
-    r"""
-    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
-    """
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
+    r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):
         data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
         data_files = model_args.export_quantization_dataset
...
@@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
             raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")

         sample_idx = random.randint(0, len(dataset) - 1)
-        sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+        sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
         n_try += 1

         if sample["input_ids"].size(1) > maxlen:
             break  # TODO: fix large maxlen
...
@@ -101,11 +97,9 @@ def configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    init_kwargs: Dict[str, Any],
+    init_kwargs: dict[str, Any],
 ) -> None:
-    r"""
-    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
-    """
+    r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
             logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
...
@@ -113,7 +107,7 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

-        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+        quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
...
src/llamafactory/model/model_utils/rope.py

-# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2025 LMSYS and the LlamaFactory team.
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # This code is inspired by the LMSYS's FastChat library.
...
src/llamafactory/model/model_utils/unsloth.py

...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.misc import get_current_device
...
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
 def _get_unsloth_kwargs(
     config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
         "max_seq_length": model_args.model_max_length or 4096,
...
@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
 def load_unsloth_pretrained_model(
     config: "PretrainedConfig", model_args: "ModelArguments"
 ) -> Optional["PreTrainedModel"]:
-    r"""
-    Optionally loads pretrained model with unsloth. Used in training.
-    """
-    from unsloth import FastLanguageModel  # type: ignore
+    r"""Optionally load pretrained model with unsloth. Used in training."""
+    from unsloth import FastLanguageModel

     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
     try:
...
@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
 def get_unsloth_peft_model(
-    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
+    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
 ) -> "PreTrainedModel":
-    r"""
-    Gets the peft model for the pretrained model with unsloth. Used in training.
-    """
-    from unsloth import FastLanguageModel  # type: ignore
+    r"""Get the peft model for the pretrained model with unsloth. Used in training."""
+    from unsloth import FastLanguageModel

     unsloth_peft_kwargs = {
         "model": model,
...
@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
 def load_unsloth_peft_model(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> "PreTrainedModel":
-    r"""
-    Loads peft model with unsloth. Used in both training and inference.
-    """
-    from unsloth import FastLanguageModel  # type: ignore
+    r"""Load peft model with unsloth. Used in both training and inference."""
+    from unsloth import FastLanguageModel

     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
     try:
...
src/llamafactory/model/model_utils/valuehead.py

...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

 import torch
 from transformers.utils import cached_file
...
@@ -30,9 +30,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
-    r"""
-    Loads value head parameters from Hugging Face Hub or local disk.
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
+    r"""Load value head parameters from Hugging Face Hub or local disk.

     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
...
src/llamafactory/model/model_utils/visual.py

-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
...
@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
+from typing import TYPE_CHECKING, Optional

 import torch
 import transformers
...
@@ -27,7 +27,7 @@ from ...extras import logging
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

     from ...hparams import FinetuningArguments, ModelArguments
...
@@ -40,9 +40,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 class CompositeModel:
     model_type: str
     projector_key: str
-    vision_model_keys: List[str]
-    language_model_keys: List[str]
-    lora_conflict_keys: List[str]
+    vision_model_keys: list[str]
+    language_model_keys: list[str]
+    lora_conflict_keys: list[str]

     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
...
@@ -51,16 +51,26 @@ class CompositeModel:
         return module


-COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
+COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}


 def _register_composite_model(
     model_type: str,
     projector_key: Optional[str] = None,
-    vision_model_keys: Optional[List[str]] = None,
-    language_model_keys: Optional[List[str]] = None,
-    lora_conflict_keys: Optional[List[str]] = None,
+    vision_model_keys: Optional[list[str]] = None,
+    language_model_keys: Optional[list[str]] = None,
+    lora_conflict_keys: Optional[list[str]] = None,
 ):
+    r"""Register a new composite model.
+
+    Args:
+        model_type: model type
+        projector_key: multi_modal_projector
+        vision_model_keys: vision_tower
+        language_model_keys: language_model
+        lora_conflict_keys: None
+
+    """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
         projector_key=projector_key or "multi_modal_projector",
...
@@ -116,12 +126,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
 def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Casts projector output to half precision for fine-tuning quantized VLMs.
-    """
+    r"""Cast projector output to half precision for fine-tuning quantized VLMs."""

     def _mm_projector_forward_post_hook(
-        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+        module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
...
@@ -137,9 +145,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
 def configure_visual_model(config: "PretrainedConfig") -> None:
-    r"""
-    Patches VLMs before loading them.
-    """
+    r"""Patch VLMs before loading them."""
     if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
         # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
...
@@ -149,10 +155,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL


-def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
-    r"""
-    Freezes vision tower and language model for VLM full/freeze tuning.
-    """
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
+    r"""Freeze vision tower and language model for VLM full/freeze tuning."""
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
     if model_type in COMPOSITE_MODELS:
...
@@ -174,47 +178,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     return forbidden_modules


-def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""
-    Computes the number of special tokens per image.
-    """
-    model_type = getattr(config, "model_type", None)
-    if model_type == "llava":
-        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
-        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
-            image_seqlen += 1
-    elif model_type == "paligemma":
-        image_seqlen = config.vision_config.num_image_tokens
-    else:
-        image_seqlen = -1
-
-    return image_seqlen
-
-
-def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Computes the patch size of the vit.
-    """
-    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
-    return patch_size
-
-
-def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Get the vision_feature_select_strategy.
-    """
-    vision_feature_select_strategy = getattr(
-        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
-    )
-    return vision_feature_select_strategy
-
-
 def patch_target_modules(
-    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> List[str]:
-    r"""
-    Freezes vision tower for VLM LoRA tuning.
-    """
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
+) -> list[str]:
+    r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
     if model_type in COMPOSITE_MODELS:
         forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
...
@@ -231,6 +198,17 @@ def patch_target_modules(
     return target_modules


+_register_composite_model(
+    model_type="gemma3",
+)
+
+
+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
+
+
 _register_composite_model(
     model_type="llava",
 )
...
@@ -285,6 +263,15 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="qwen2_5_omni_thinker",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
...
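For orientation, here is a hedged sketch of how a registration such as the new llama4 entry above is typically consumed when freezing the vision tower; only COMPOSITE_MODELS, the dataclass fields, and the registered keys come from the diff, and the freezing logic is simplified rather than copied from get_forbidden_modules:

from dataclasses import dataclass, field

@dataclass
class CompositeModel:
    model_type: str
    projector_key: str = "multi_modal_projector"
    vision_model_keys: list[str] = field(default_factory=lambda: ["vision_tower"])
    language_model_keys: list[str] = field(default_factory=lambda: ["language_model"])
    lora_conflict_keys: list[str] = field(default_factory=list)

COMPOSITE_MODELS: dict[str, CompositeModel] = {}
COMPOSITE_MODELS["llama4"] = CompositeModel("llama4", vision_model_keys=["vision_model"])

def forbidden_modules_sketch(model_type: str, freeze_vision_tower: bool) -> set[str]:
    # When the vision tower is frozen, its parameter prefixes are excluded
    # from full/freeze tuning for registered composite (multimodal) models.
    forbidden: set[str] = set()
    if model_type in COMPOSITE_MODELS and freeze_vision_tower:
        forbidden.update(COMPOSITE_MODELS[model_type].vision_model_keys)
    return forbidden

print(forbidden_modules_sketch("llama4", freeze_vision_tower=True))  # {'vision_model'}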
src/llamafactory/model/patcher.py
View file @
7ea81099
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
from
types
import
MethodType
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
from
peft
import
PeftModel
from
peft
import
PeftModel
...
@@ -27,19 +27,14 @@ from ..extras.packages import is_transformers_version_greater_than
...
@@ -27,19 +27,14 @@ from ..extras.packages import is_transformers_version_greater_than
from
.model_utils.attention
import
configure_attn_implementation
,
print_attn_implementation
from
.model_utils.attention
import
configure_attn_implementation
,
print_attn_implementation
from
.model_utils.checkpointing
import
prepare_model_for_training
from
.model_utils.checkpointing
import
prepare_model_for_training
from
.model_utils.embedding
import
resize_embedding_layer
from
.model_utils.embedding
import
resize_embedding_layer
from
.model_utils.kv_cache
import
configure_kv_cache
from
.model_utils.longlora
import
configure_longlora
from
.model_utils.longlora
import
configure_longlora
from
.model_utils.moe
import
add_z3_leaf_module
,
configure_moe
from
.model_utils.moe
import
add_z3_leaf_module
,
configure_moe
from
.model_utils.packing
import
configure_packing
from
.model_utils.packing
import
configure_packing
from
.model_utils.quantization
import
configure_quantization
from
.model_utils.quantization
import
configure_quantization
from
.model_utils.rope
import
configure_rope
from
.model_utils.rope
import
configure_rope
from
.model_utils.valuehead
import
prepare_valuehead_model
from
.model_utils.valuehead
import
prepare_valuehead_model
from
.model_utils.visual
import
(
from
.model_utils.visual
import
autocast_projector_dtype
,
configure_visual_model
autocast_projector_dtype
,
configure_visual_model
,
get_image_seqlen
,
get_patch_size
,
get_vision_feature_select_strategy
,
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -56,8 +51,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
...
@@ -56,8 +51,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
if
"PreTrainedTokenizerBase"
not
in
str
(
tokenizer
.
_pad
.
__func__
):
if
"PreTrainedTokenizerBase"
not
in
str
(
tokenizer
.
_pad
.
__func__
):
tokenizer
.
_pad
=
MethodType
(
PreTrainedTokenizerBase
.
_pad
,
tokenizer
)
tokenizer
.
_pad
=
MethodType
(
PreTrainedTokenizerBase
.
_pad
,
tokenizer
)
if
model_args
.
model_max_length
is
not
None
and
tokenizer
.
model_max_length
!=
model_args
.
model_max_length
:
if
model_args
.
model_max_length
is
not
None
and
tokenizer
.
model_max_length
<
model_args
.
model_max_length
:
tokenizer
.
model_max_length
=
model_args
.
model_max_length
tokenizer
.
model_max_length
=
model_args
.
model_max_length
# enlarge the tokenizer max length
if
model_args
.
new_special_tokens
is
not
None
:
if
model_args
.
new_special_tokens
is
not
None
:
num_added_tokens
=
tokenizer
.
add_special_tokens
(
num_added_tokens
=
tokenizer
.
add_special_tokens
(
...
@@ -72,28 +67,25 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
...
@@ -72,28 +67,25 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
def
patch_processor
(
def
patch_processor
(
processor
:
"ProcessorMixin"
,
processor
:
"ProcessorMixin"
,
config
:
"PretrainedConfig"
,
tokenizer
:
"PreTrainedTokenizer"
,
tokenizer
:
"PreTrainedTokenizer"
,
model_args
:
"ModelArguments"
,
model_args
:
"ModelArguments"
,
)
->
None
:
)
->
None
:
setattr
(
processor
,
"tokenizer"
,
tokenizer
)
setattr
(
processor
,
"tokenizer"
,
tokenizer
)
if
getattr
(
config
,
"vision_config"
,
None
)
is
not
None
:
# visual models
setattr
(
processor
,
"image_max_pixels"
,
model_args
.
image_max_pixels
)
setattr
(
processor
,
"image_seqlen"
,
get_image_seqlen
(
config
))
setattr
(
processor
,
"image_min_pixels"
,
model_args
.
image_min_pixels
)
setattr
(
processor
,
"patch_size"
,
get_patch_size
(
config
,
processor
))
setattr
(
processor
,
"image_do_pan_and_scan"
,
model_args
.
image_do_pan_and_scan
)
setattr
(
processor
,
"image_max_pixels"
,
model_args
.
image_max_pixels
)
setattr
(
processor
,
"video_max_pixels"
,
model_args
.
video_max_pixels
)
setattr
(
processor
,
"image_min_pixels"
,
model_args
.
image_min_pixels
)
setattr
(
processor
,
"video_min_pixels"
,
model_args
.
video_min_pixels
)
setattr
(
processor
,
"video_max_pixels"
,
model_args
.
video_max_pixels
)
setattr
(
processor
,
"video_fps"
,
model_args
.
video_fps
)
setattr
(
processor
,
"video_min_pixels"
,
model_args
.
video_min_pixels
)
setattr
(
processor
,
"video_maxlen"
,
model_args
.
video_maxlen
)
setattr
(
processor
,
"video_fps"
,
model_args
.
video_fps
)
setattr
(
processor
,
"audio_sampling_rate"
,
model_args
.
audio_sampling_rate
)
setattr
(
processor
,
"video_maxlen"
,
model_args
.
video_maxlen
)
setattr
(
processor
,
"vision_feature_select_strategy"
,
get_vision_feature_select_strategy
(
config
,
processor
))
def
patch_config
(
def
patch_config
(
config
:
"PretrainedConfig"
,
config
:
"PretrainedConfig"
,
tokenizer
:
"PreTrainedTokenizer"
,
tokenizer
:
"PreTrainedTokenizer"
,
model_args
:
"ModelArguments"
,
model_args
:
"ModelArguments"
,
init_kwargs
:
D
ict
[
str
,
Any
],
init_kwargs
:
d
ict
[
str
,
Any
],
is_trainable
:
bool
,
is_trainable
:
bool
,
)
->
None
:
)
->
None
:
if
model_args
.
compute_dtype
is
None
:
# priority: bf16 > fp16 > fp32
if
model_args
.
compute_dtype
is
None
:
# priority: bf16 > fp16 > fp32
...
@@ -112,19 +104,13 @@ def patch_config(
...
@@ -112,19 +104,13 @@ def patch_config(
configure_moe
(
config
,
model_args
,
is_trainable
)
configure_moe
(
config
,
model_args
,
is_trainable
)
configure_visual_model
(
config
)
configure_visual_model
(
config
)
configure_packing
(
model_args
,
is_trainable
)
configure_packing
(
model_args
,
is_trainable
)
configure_kv_cache
(
config
,
model_args
,
is_trainable
)
if
model_args
.
use_cache
and
not
is_trainable
:
setattr
(
config
,
"use_cache"
,
True
)
logger
.
info_rank0
(
"Using KV cache for faster generation."
)
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen"
:
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen"
:
setattr
(
config
,
"use_flash_attn"
,
model_args
.
flash_attn
==
"fa2"
)
setattr
(
config
,
"use_flash_attn"
,
model_args
.
flash_attn
==
"fa2"
)
for
dtype_name
,
dtype
in
[(
"fp16"
,
torch
.
float16
),
(
"bf16"
,
torch
.
bfloat16
),
(
"fp32"
,
torch
.
float32
)]:
for
dtype_name
,
dtype
in
[(
"fp16"
,
torch
.
float16
),
(
"bf16"
,
torch
.
bfloat16
),
(
"fp32"
,
torch
.
float32
)]:
setattr
(
config
,
dtype_name
,
model_args
.
compute_dtype
==
dtype
)
setattr
(
config
,
dtype_name
,
model_args
.
compute_dtype
==
dtype
)
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen2"
and
is_trainable
and
model_args
.
flash_attn
==
"fa2"
:
setattr
(
config
,
"use_cache"
,
False
)
# qwen2 does not support use_cache when using flash attn
if
getattr
(
config
,
"model_type"
,
None
)
==
"minicpmo"
:
if
getattr
(
config
,
"model_type"
,
None
)
==
"minicpmo"
:
setattr
(
config
,
"init_audio"
,
True
)
setattr
(
config
,
"init_audio"
,
True
)
setattr
(
config
,
"init_tts"
,
False
)
setattr
(
config
,
"init_tts"
,
False
)
...
@@ -138,15 +124,13 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
-    # cast data type of the model if:
-    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
-    # 2. quantization_bit is not None (qlora)
-    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
+    # do not cast data type of the model deepspeed zero3 without qlora
+    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
         init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
+    if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
         if "device_map" not in init_kwargs and model_args.device_map:
-            init_kwargs["device_map"] = model_args.device_map
+            init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True
         if init_kwargs.get("device_map", None) == "auto":
             init_kwargs["offload_folder"] = model_args.offload_folder
...
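The casting condition above is not a pure refactor: the old and new predicates disagree when FSDP is enabled without quantization. A small, self-contained check (illustrative only, not project code) makes the difference visible:

from itertools import product


def old_condition(zero3: bool, fsdp: bool, qlora: bool) -> bool:
    # (not deepspeed zero3 and not fsdp) or qlora
    return (not zero3 and not fsdp) or qlora


def new_condition(zero3: bool, fsdp: bool, qlora: bool) -> bool:
    # not (deepspeed zero3 without qlora)
    return not (zero3 and not qlora)


for zero3, fsdp, qlora in product([False, True], repeat=3):
    if old_condition(zero3, fsdp, qlora) != new_condition(zero3, fsdp, qlora):
        print(f"zero3={zero3} fsdp={fsdp} qlora={qlora}")
# Prints exactly one case, zero3=False fsdp=True qlora=False: the new code now sets torch_dtype
# for FSDP runs and handles FSDP separately via the is_fsdp_enabled() guard on the device map.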
src/llamafactory/train/callbacks.py  View file @ 7ea81099
...
@@ -19,7 +19,7 @@ import sys
 import time
 from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional
 import torch
 import transformers
...
@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
 def fix_valuehead_checkpoint(
     model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
 ) -> None:
-    r"""
+    r"""Fix the valuehead checkpoint files.
     The model is already unwrapped.
     There are three cases:
...
@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
     if safe_serialization:
         path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
         with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
     os.remove(path_to_checkpoint)
     decoder_state_dict, v_head_state_dict = {}, {}
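Right after this hunk (elided here), the loaded state dict is split into decoder weights and value-head weights. A minimal illustration of that kind of split, assuming for the example that value-head tensors are the ones whose names start with "v_head." and that decoder weights carry a "pretrained_model." wrapper prefix; the real function's naming rules are not shown in this hunk:

import torch

# Toy state dict standing in for the loaded checkpoint.
state_dict: dict[str, torch.Tensor] = {
    "v_head.summary.weight": torch.zeros(1, 8),
    "v_head.summary.bias": torch.zeros(1),
    "pretrained_model.model.embed_tokens.weight": torch.zeros(4, 8),
}

decoder_state_dict, v_head_state_dict = {}, {}
for name, tensor in state_dict.items():
    if name.startswith("v_head."):
        v_head_state_dict[name] = tensor
    else:
        # strip the wrapper prefix so the decoder weights can be loaded into the bare model
        decoder_state_dict[name.replace("pretrained_model.", "", 1)] = tensor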
...
@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
 class FixValueHeadModelCallback(TrainerCallback):
-    r"""
-    A callback for fixing the checkpoint for valuehead models.
-    """
+    r"""A callback for fixing the checkpoint for valuehead models."""
     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
...
@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
 class SaveProcessorCallback(TrainerCallback):
-    r"""
-    A callback for saving the processor.
-    """
+    r"""A callback for saving the processor."""
     def __init__(self, processor: "ProcessorMixin") -> None:
         self.processor = processor
...
@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
 class PissaConvertCallback(TrainerCallback):
-    r"""
-    A callback for converting the PiSSA adapter to a normal one.
-    """
+    r"""A callback for converting the PiSSA adapter to a normal one."""
     @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
...
@@ -166,20 +161,17 @@ class PissaConvertCallback(TrainerCallback):
             model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
-            model.save_pretrained(
-                pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
-            )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
+            model.save_pretrained(
+                pissa_convert_dir,
+                safe_serialization=args.save_safetensors,
+                path_initial_model_for_weight_conversion=pissa_init_dir,
+            )
             model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
             model.set_adapter("default")
-            if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
-                model.delete_adapter("pissa_init")
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
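Resolving the old TODO, the conversion now relies on peft's path_initial_model_for_weight_conversion argument (available since peft 0.12.0) instead of the deprecated convert_pissa_to_lora, which also makes the old "pissa_init" adapter cleanup unnecessary. The backup-and-reload around the converting save exists so that training can continue from the unconverted PiSSA adapter. A condensed sketch of that cycle, reusing the variable names from the hunk above (not a drop-in snippet; the directories are prepared earlier in the callback):

# model is a PeftModel; pissa_backup_dir, pissa_convert_dir and pissa_init_dir are assumed to exist.
model.save_pretrained(pissa_backup_dir, safe_serialization=True)    # keep the trainable PiSSA adapter
model.save_pretrained(
    pissa_convert_dir,
    safe_serialization=True,
    path_initial_model_for_weight_conversion=pissa_init_dir,        # peft converts PiSSA -> plain LoRA here
)
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)  # restore the unconverted adapter
model.set_adapter("default")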
 class LogCallback(TrainerCallback):
-    r"""
-    A callback for logging training and evaluation status.
-    """
+    r"""A callback for logging training and evaluation status."""
     def __init__(self) -> None:
         # Progress
...
@@ -188,7 +180,7 @@ class LogCallback(TrainerCallback):
         self.max_steps = 0
         self.elapsed_time = ""
         self.remaining_time = ""
-        self.thread_pool: Optional["ThreadPoolExecutor"] = None
+        self.thread_pool: Optional[ThreadPoolExecutor] = None
         # Status
         self.aborted = False
         self.do_train = False
...
@@ -219,7 +211,7 @@ class LogCallback(TrainerCallback):
         self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
         self.remaining_time = str(timedelta(seconds=int(remaining_time)))
-    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
+    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
         with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
             f.write(json.dumps(logs) + "\n")
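_write_log appends one JSON object per line to the trainer log file, so a run can be tailed or post-processed with a few lines of Python. A small reader sketch, assuming TRAINER_LOG resolves to "trainer_log.jsonl" (the filename is an assumption here, not read from the diff):

import json
import os


def read_trainer_log(output_dir: str) -> list[dict]:
    records = []
    with open(os.path.join(output_dir, "trainer_log.jsonl"), encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))  # one JSON object per line, as written by _write_log
    return records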
...
@@ -348,9 +340,7 @@ class LogCallback(TrainerCallback):
 class ReporterCallback(TrainerCallback):
-    r"""
-    A callback for reporting training status to external logger.
-    """
+    r"""A callback for reporting training status to external logger."""
     def __init__(
         self,
...
src/llamafactory/train/dpo/trainer.py  View file @ 7ea81099
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
...
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union
 import torch
 import torch.nn.functional as F
...
@@ -128,16 +128,12 @@ class CustomDPOTrainer(DPOTrainer):
         return super()._get_train_sampler()
     @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+    def get_batch_samples(self, *args, **kwargs):
+        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
+        return Trainer.get_batch_samples(self, *args, **kwargs)
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
-        """
+        r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
         log_odds = (chosen_logps - rejected_logps) - (
             torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
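As an aside (not part of the diff), the expression above is the log odds ratio: with p_w = exp(chosen_logps) and p_l = exp(rejected_logps) it equals log[(p_w / (1 - p_w)) / (p_l / (1 - p_l))]; in the ORPO formulation the final loss then combines an SFT term with -logsigmoid of this quantity, though that part is elided from the hunk. A tiny numeric check of the identity, illustrative only:

import torch

# Per-sequence average log probabilities: p_w ~= 0.8 and p_l ~= 0.2.
chosen_logps = torch.tensor([-0.2231])
rejected_logps = torch.tensor([-1.6094])

log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)

p_w, p_l = chosen_logps.exp(), rejected_logps.exp()
direct = torch.log((p_w / (1 - p_w)) / (p_l / (1 - p_l)))
print(torch.allclose(log_odds, direct))  # True: both evaluate to log(16) ~= 2.77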
...
@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
         return orpo_loss
     def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Computes SimPO loss for batched log probabilities of the policy model.
-        """
+        r"""Compute SimPO loss for batched log probabilities of the policy model."""
         pi_logratios = chosen_logps - rejected_logps
         gamma_logratios = self.simpo_gamma / self.beta
         logits = pi_logratios - gamma_logratios
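SimPO operates on length-normalized log probabilities and subtracts a reward margin: the logits above are (chosen_logps - rejected_logps) - gamma/beta, and in the standard SimPO formulation the loss is then -logsigmoid(beta * logits). The return statement is elided from the hunk, so the following is only a sketch of that last step under those assumptions:

import torch
import torch.nn.functional as F

beta, simpo_gamma = 2.0, 0.5
chosen_logps = torch.tensor([-1.1])    # average log prob per token of the chosen response
rejected_logps = torch.tensor([-1.9])  # average log prob per token of the rejected response

pi_logratios = chosen_logps - rejected_logps
gamma_logratios = simpo_gamma / beta
logits = pi_logratios - gamma_logratios
simpo_loss = -F.logsigmoid(beta * logits)  # small when chosen beats rejected by more than the margin
print(simpo_loss)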
...
@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
         policy_rejected_logps: "torch.Tensor",
         reference_chosen_logps: Optional["torch.Tensor"],
         reference_rejected_logps: Optional["torch.Tensor"],
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes loss for preference learning.
-        """
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute loss for preference learning."""
         if not self.finetuning_args.use_ref_model:
             if self.loss_type == "orpo":
                 losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
...
@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
             batch = nested_detach(batch, clone=True)  # avoid error
-        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
+        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
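For IPO, ORPO and SimPO the per-sequence log probability is divided by the number of valid label tokens, so long and short responses are compared on an equal footing. A self-contained sketch of that sum-then-normalize pattern; batch_logps is a simplified stand-in for get_batch_logps, which additionally handles multimodal inputs and other details not shown here:

import torch

IGNORE_INDEX = -100  # labels with this value do not contribute


def batch_logps(logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # logits: (batch, seq_len, vocab); labels: (batch, seq_len); shift so each position predicts the next token
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    mask = labels != IGNORE_INDEX
    safe_labels = torch.where(mask, labels, 0)
    per_token = torch.gather(logits.log_softmax(-1), 2, safe_labels.unsqueeze(2)).squeeze(2)
    logps = (per_token * mask).sum(-1)   # summed log probability per sequence
    valid_length = mask.sum(-1)          # number of label tokens that actually count
    return logps, valid_length


logits = torch.randn(2, 6, 32)
labels = torch.tensor([[1, 2, 3, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX],
                       [4, 5, 6, 7, 8, 9]])
logps, valid_length = batch_logps(logits, labels)
avg_logps = logps / valid_length  # the normalization applied above for ipo / orpo / simpo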
...
@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
+        r"""Compute log probabilities of the reference model."""
         if not self.finetuning_args.use_ref_model:
             return None, None
...
@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
+        batch: dict[str, "torch.Tensor"],
         train_eval: Literal["train", "eval"] = "train",
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,
...
@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)
     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
...
src/llamafactory/train/dpo/workflow.py  View file @ 7ea81099
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
...
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
...
@@ -38,7 +38,7 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
...
src/llamafactory/train/kto/trainer.py  View file @ 7ea81099
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
...
@@ -19,7 +19,7 @@ import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union
 import torch
 from transformers import Trainer
...
@@ -120,28 +120,22 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
-        r"""
-        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
-        """
+        r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)
         return Trainer._get_train_sampler(self)
     @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+    def get_batch_samples(self, *args, **kwargs):
+        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
+        return Trainer.get_batch_samples(self, *args, **kwargs)
     @override
     def forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Runs forward pass and computes the log probabilities.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Run forward pass and computes the log probabilities."""
         batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
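The prefix argument lets one forward pass serve both the target examples and the KL examples: a KTO batch carries each field twice, once unprefixed and once under kl_* keys, and forward(model, batch, prefix="kl_") simply reads the second set. A schematic of such a batch; the keys and shapes below are illustrative, and the real collator also carries labels and any multimodal fields:

import torch

# Illustrative KTO-style batch: the kl_* copies pair each prompt with an unrelated
# completion so the trainer can estimate the reference KL term.
batch = {
    "input_ids": torch.randint(0, 100, (2, 16)),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
    "kl_input_ids": torch.randint(0, 100, (2, 16)),
    "kl_attention_mask": torch.ones(2, 16, dtype=torch.long),
}

prefix = "kl_"
model_inputs = {
    "input_ids": batch[f"{prefix}input_ids"],           # the same lookup done in forward() above
    "attention_mask": batch[f"{prefix}attention_mask"],
}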
...
@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
             _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
...
@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def compute_reference_log_probs(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        r"""
-        Computes log probabilities of the reference model.
-        """
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""Compute log probabilities of the reference model."""
         if self.ref_model is None:
             ref_model = model
             ref_context = self.accelerator.unwrap_model(model).disable_adapter()
...
@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
-        batch: Dict[str, "torch.Tensor"],
-    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
-        r"""
-        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
-        """
+        batch: dict[str, "torch.Tensor"],
+    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
+        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
         (
             policy_chosen_logps,
...
@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Subclass and override to accept extra kwargs.
-        """
+        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
+        r"""Subclass and override to accept extra kwargs."""
         return super().compute_loss(model, inputs, return_outputs)
     @override
-    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
-        r"""
-        Log `logs` on the various objects watching training, including stored metrics.
-        """
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Log `logs` on the various objects watching training, including stored metrics."""
         # logs either has "loss" or "eval_loss"
         train_eval = "train" if "loss" in logs else "eval"
         prefix = "eval_" if train_eval == "eval" else ""
...
@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
         metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
         metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
-        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
         for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
             if f"count/{split}" in metric_dict:
                 for key in ("rewards", "logps", "logits"):
...
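The reduce-then-average pattern at the end of the hunk above keeps per-process sums together with example counts and divides only after summing across all ranks, which yields a correctly weighted mean even when ranks see different numbers of chosen or rejected examples. A single-process illustration of that arithmetic; the composite key names are modeled on the pieces visible above and are otherwise hypothetical:

# Summed metrics after the all-reduce; the count travels alongside the sums.
metric_dict = {
    "rewards/chosen": 12.0,  # sum of rewards over all chosen examples on all ranks
    "logps/chosen": -84.0,
    "logits/chosen": 3.0,
    "count/chosen": 6.0,     # number of chosen examples contributing to the sums
}

if "count/chosen" in metric_dict and metric_dict["count/chosen"] > 0:
    for key in ("rewards", "logps", "logits"):
        metric_dict[f"{key}/chosen"] /= metric_dict["count/chosen"]  # sum -> mean

print(metric_dict["rewards/chosen"])  # 2.0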
src/llamafactory/train/kto/workflow.py  View file @ 7ea81099
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
...
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional
 from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
...
@@ -37,7 +37,7 @@ def run_kto(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = None,
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
...
src/llamafactory/train/ppo/ppo_utils.py  View file @ 7ea81099
...
@@ -14,7 +14,7 @@
 import json
 from contextlib import nullcontext
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+from typing import TYPE_CHECKING, Literal, Optional
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
...
@@ -31,10 +31,8 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
-def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
-    r"""
-    Gets reward scores from the API server.
-    """
+def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
+    r"""Get reward scores from the API server."""
     headers = {"Content-Type": "application/json"}
     payload = {"model": "model", "messages": messages}
     response = requests.post(server_url, json=payload, headers=headers)
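The helper posts the generated messages to a scoring endpoint and turns the returned scores into tensors; since the response parsing is elided from this hunk, the sketch below assumes a JSON body with a "scores" field purely for illustration:

import json

import requests
import torch


# Hypothetical, self-contained variant of the reward client above. The "scores" field in the
# response is an assumption; the actual server contract is not shown in this hunk.
def query_reward_server(server_url: str, messages: list[str]) -> list[torch.Tensor]:
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    response.raise_for_status()
    scores = json.loads(response.text)["scores"]  # assumed response schema
    return [torch.tensor(score, dtype=torch.float32) for score in scores]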
...
@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    r"""
-    Replaces the default/reward modules in the model. The model is already unwrapped.
-    """
+    r"""Replace the default/reward modules in the model. The model is already unwrapped."""
     v_head_layer = model.v_head.summary
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
...
@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
     v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
-    r"""
-    Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
-    """
+def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
+    r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
     layer_norm_params = {}
     for name, param in model.named_parameters():
         if param.data.dtype == torch.float32:
...
@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     return layer_norm_params
-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
-    r"""
-    Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
-    """
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
+    r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
     for name, param in model.named_parameters():
         if name in layernorm_params:
             param.data = layernorm_params[name]
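dump_layernorm and restore_layernorm form a snapshot-and-rollback pair: the float32 (typically layernorm) parameters are copied out, can then be temporarily cast for generation, and are written back verbatim afterwards. A toy round trip, independent of the trainer; the cast in the middle only stands in for whatever happens between the two calls:

import torch
import torch.nn as nn

# A toy module whose LayerNorm stays in float32 while the linear layer is half precision.
model = nn.Sequential(nn.Linear(4, 4).half(), nn.LayerNorm(4))

# dump: keep a copy of every float32 parameter, keyed by name (mirrors dump_layernorm above)
snapshot = {
    name: param.data.detach().clone()
    for name, param in model.named_parameters()
    if param.data.dtype == torch.float32
}

# temporarily cast the live parameters, e.g. so generation runs fully in half precision
for name, param in model.named_parameters():
    if name in snapshot:
        param.data = param.data.half()

# restore: put the original float32 tensors back (mirrors restore_layernorm above)
for name, param in model.named_parameters():
    if name in snapshot:
        param.data = snapshot[name]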