OpenDAS / LLaMA-Factory / Commits / 317a82e2

Commit 317a82e2, authored Mar 07, 2025 by chenych
Add QWQ-32B
parent 37b0ad9f

Changes: 255. Showing 20 changed files with 37 additions and 1028 deletions (+37, -1028).
src/llamafactory/model/utils/longlora.py        +0   -323
src/llamafactory/model/utils/misc.py            +0   -74
src/llamafactory/model/utils/mod.py             +0   -28
src/llamafactory/model/utils/moe.py             +0   -61
src/llamafactory/model/utils/quantization.py    +0   -150
src/llamafactory/model/utils/rope.py            +0   -47
src/llamafactory/model/utils/unsloth.py         +0   -88
src/llamafactory/model/utils/valuehead.py       +0   -58
src/llamafactory/model/utils/visual.py          +0   -84
src/llamafactory/train/callbacks.py             +4   -4
src/llamafactory/train/dpo/__init__.py          +1   -1
src/llamafactory/train/dpo/trainer.py           +10  -13
src/llamafactory/train/kto/__init__.py          +1   -1
src/llamafactory/train/kto/trainer.py           +14  -12
src/llamafactory/train/ppo/__init__.py          +1   -1
src/llamafactory/train/ppo/ppo_utils.py         +1   -1
src/llamafactory/train/ppo/utils.py             +0   -59
src/llamafactory/train/pt/__init__.py           +1   -1
src/llamafactory/train/pt/trainer.py            +3   -21
src/llamafactory/train/rm/__init__.py           +1   -1

src/llamafactory/model/utils/longlora.py    deleted 100644 → 0    (view file @ 37b0ad9f)

import math
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    Cache,
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version

from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
    self: "LlamaAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional["Cache"] = None,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states: "torch.Tensor" = self.q_proj(hidden_states)
    key_states: "torch.Tensor" = self.k_proj(hidden_states)
    value_states: "torch.Tensor" = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    past_key_value = getattr(self, "past_key_value", past_key_value)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)  # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
    attn_output = attn_output.transpose(1, 2).contiguous()

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
    self: "LlamaFlashAttention2",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional["Cache"] = None,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # LlamaFlashAttention2 attention does not support output_attentions
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states: "torch.Tensor" = self.q_proj(hidden_states)
    key_states: "torch.Tensor" = self.k_proj(hidden_states)
    value_states: "torch.Tensor" = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    past_key_value = getattr(self, "past_key_value", past_key_value)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once("The input hidden states seems to be silently casted in float32.")
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
    else:
        groupsz = q_len

    attn_output: torch.Tensor = self._flash_attention_forward(
        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
    )

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
    self: "LlamaSdpaAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional["Cache"] = None,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if output_attentions:
        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
        return llama_attention_forward(
            self,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            cache_position=cache_position,
            **kwargs,
        )

    bsz, q_len, _ = hidden_states.size()

    query_states: "torch.Tensor" = self.q_proj(hidden_states)
    key_states: "torch.Tensor" = self.k_proj(hidden_states)
    value_states: "torch.Tensor" = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
        num_groups = q_len // groupsz

        def shift(state: torch.Tensor) -> torch.Tensor:
            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
            state = torch.cat(
                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
                dim=2,
            )
            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)

        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

    if query_states.device.type == "cuda" and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=causal_mask is None and q_len > 1,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
        attn_output = torch.cat(
            (
                attn_output[:, :, : self.num_heads // 2],
                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
            )
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value


def _apply_llama_patch() -> None:
    require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
    LlamaAttention.forward = llama_attention_forward
    LlamaFlashAttention2.forward = llama_flash_attention_2_forward
    LlamaSdpaAttention.forward = llama_sdpa_attention_forward


def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.shift_attn:
        return

    logger = get_logger(__name__)
    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
        setattr(config, "group_size_ratio", 0.25)
        _apply_llama_patch()
        logger.info("Using shift short attention with group_size_ratio=1/4.")
    else:
        logger.warning("Current model does not support shift short attention.")
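
Note (not part of the commit): the heart of the deleted file is the shift() helper, which rolls half of the attention heads by half a group along the sequence axis before splitting the sequence into groups (LongLoRA's shift short attention). A minimal, self-contained sketch of that transform with made-up toy sizes:

# Standalone illustration only; tensor sizes are invented for the example.
import torch

bsz, seq_len, n_heads, head_dim = 1, 8, 4, 2
groupsz = 4                       # would come from q_len * group_size_ratio
num_groups = seq_len // groupsz

state = torch.arange(bsz * seq_len * n_heads * head_dim, dtype=torch.float32)
state = state.reshape(bsz, seq_len, n_heads, head_dim)

# roll the second half of the heads by half a group along the sequence axis
shifted = torch.cat(
    (state[:, :, : n_heads // 2], state[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1)),
    dim=2,
)
# regroup the sequence so attention is computed within each group
grouped = shifted.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
print(grouped.shape)  # torch.Size([2, 4, 4, 2])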

src/llamafactory/model/utils/misc.py    deleted 100644 → 0    (view file @ 37b0ad9f)

from typing import TYPE_CHECKING, List

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


logger = get_logger(__name__)


def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
    r"""
    Finds all available modules to apply lora or galore.
    """
    forbidden_modules = {"lm_head"}

    if model.config.model_type == "chatglm":
        forbidden_modules.add("output_layer")
    elif model.config.model_type == "internlm2":
        forbidden_modules.add("output")
    elif model.config.model_type in ["llava", "paligemma"]:
        forbidden_modules.add("multi_modal_projector")

    if freeze_vision_tower:
        forbidden_modules.add("vision_tower")

    module_names = set()
    for name, module in model.named_modules():
        if any(forbidden_module in name for forbidden_module in forbidden_modules):
            continue

        if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
            module_names.add(name.split(".")[-1])

    logger.info("Found linear modules: {}".format(",".join(module_names)))
    return list(module_names)


def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
    r"""
    Finds the modules in the expanded blocks to apply lora.
    """
    num_layers = getattr(model.config, "num_hidden_layers", None)
    if not num_layers:
        raise ValueError("Model was not supported.")

    if num_layers % num_layer_trainable != 0:
        raise ValueError(
            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
        )

    stride = num_layers // num_layer_trainable
    trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
    module_names = []
    for name, _ in model.named_modules():
        if any(target_module in name for target_module in target_modules) and any(
            trainable_layer in name for trainable_layer in trainable_layers
        ):
            module_names.append(name)

    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
    return module_names


def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
    if "AutoConfig" in getattr(config, "auto_map", {}):
        config.__class__.register_for_auto_class()
    if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
        model.__class__.register_for_auto_class()
    if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
        tokenizer.__class__.register_for_auto_class()
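
Note (not part of the commit): the linear-module scan above only keeps the last path component of modules whose class name contains "Linear", skipping forbidden names. A toy sketch of that behavior; the model class and layer names below are invented for the example:

# Standalone illustration only.
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)
        self.lm_head = nn.Linear(8, 8)  # forbidden module, excluded below

forbidden_modules = {"lm_head"}
module_names = set()
for name, module in ToyModel().named_modules():
    if any(m in name for m in forbidden_modules):
        continue
    if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
        module_names.add(name.split(".")[-1])

print(module_names)  # {'q_proj', 'v_proj'}, the candidate LoRA target modules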

src/llamafactory/model/utils/mod.py    deleted 100644 → 0    (view file @ 37b0ad9f)

from typing import TYPE_CHECKING

from ...extras.constants import MOD_SUPPORTED_MODELS


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    from MoD import AutoMoDModelForCausalLM

    return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)


def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    from MoD import apply_mod_to_hf

    if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    model = apply_mod_to_hf(model)
    model = model.to(model_args.compute_dtype)
    return model

src/llamafactory/model/utils/moe.py    deleted 100644 → 0    (view file @ 37b0ad9f)

from typing import TYPE_CHECKING

from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


def add_z3_leaf_module(model: "PreTrainedModel") -> None:
    r"""
    Sets module as a leaf module to skip partitioning in deepspeed zero3.
    """
    if not is_deepspeed_zero3_enabled():
        return

    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    if getattr(model.config, "model_type", None) == "dbrx":
        from transformers.models.dbrx.modeling_dbrx import DbrxFFN

        set_z3_leaf_modules(model, [DbrxFFN])

    if getattr(model.config, "model_type", None) == "jamba":
        from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

        set_z3_leaf_modules(model, [JambaSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "jetmoe":
        from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

        set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

    if getattr(model.config, "model_type", None) == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "qwen2moe":
        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if model_args.moe_aux_loss_coef is not None:
        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
        elif getattr(config, "model_type", None) == "deepseek":
            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
        elif getattr(config, "model_type", None) == "jetmoe":
            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)

    if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
        setattr(config, "output_router_logits", is_trainable)

src/llamafactory/model/utils/quantization.py    deleted 100644 → 0    (view file @ 37b0ad9f)

import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version

from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer

    from ...hparams import ModelArguments


logger = get_logger(__name__)


@unique
class QuantizationMethod(str, Enum):
    r"""
    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
    """

    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    QUANTO = "quanto"
    EETQ = "eetq"
    HQQ = "hqq"


def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
    r"""
    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
    """
    if os.path.isfile(model_args.export_quantization_dataset):
        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
        data_files = model_args.export_quantization_dataset
    else:
        data_path = model_args.export_quantization_dataset
        data_files = None

    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
    maxlen = model_args.export_quantization_maxlen

    samples = []
    for _ in range(model_args.export_quantization_nsamples):
        while True:
            sample_idx = random.randint(0, len(dataset) - 1)
            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
            if sample["input_ids"].size(1) >= maxlen:
                break  # TODO: fix large maxlen

        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))

    return samples


def configure_quantization(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
) -> None:
    r"""
    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
    """
    if getattr(config, "quantization_config", None):  # ptq
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

        if model_args.quantization_device_map != "auto":
            init_kwargs["device_map"] = {"": get_current_device()}

        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")

        if quant_method == QuantizationMethod.GPTQ:
            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
            quantization_config.pop("disable_exllama", None)  # remove deprecated args
            quantization_config["use_exllama"] = False  # disable exllama

        if quant_method == QuantizationMethod.AWQ:
            require_version("autoawq", "To fix: pip install autoawq")

        if quant_method == QuantizationMethod.AQLM:
            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
            quantization_config["bits"] = 2

        quant_bits = quantization_config.get("bits", "?")
        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

    elif model_args.export_quantization_bit is not None:  # auto-gptq
        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
        from accelerate.utils import get_max_memory

        if getattr(config, "model_type", None) == "chatglm":
            raise ValueError("ChatGLM model is not supported.")

        init_kwargs["quantization_config"] = GPTQConfig(
            bits=model_args.export_quantization_bit,
            tokenizer=tokenizer,
            dataset=_get_quantization_dataset(tokenizer, model_args),
        )
        init_kwargs["device_map"] = "auto"
        init_kwargs["max_memory"] = get_max_memory()
        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))

    elif model_args.quantization_bit is not None:  # bnb
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type,
                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
            )

        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
            if model_args.quantization_bit != 4:
                raise ValueError("Only 4-bit quantized model can use auto device map.")

            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
            init_kwargs["torch_dtype"] = model_args.compute_dtype  # fsdp+qlora requires same dtype
        else:
            init_kwargs["device_map"] = {"": get_current_device()}

        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

src/llamafactory/model/utils/rope.py    deleted 100644 → 0    (view file @ 37b0ad9f)

import math
from typing import TYPE_CHECKING

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if model_args.rope_scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning("Current model does not support RoPE scaling.")
        return

    if is_trainable:
        if model_args.rope_scaling == "dynamic":
            logger.warning(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        current_max_length = getattr(config, "max_position_embeddings", None)
        if current_max_length and model_args.model_max_length > current_max_length:
            logger.info(
                "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
            )
            setattr(config, "max_position_embeddings", model_args.model_max_length)
            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
        else:
            logger.warning("Input length is smaller than max length. Consider increase input length.")
            scaling_factor = 1.0
    else:
        scaling_factor = 2.0

    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
    logger.info(
        "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
    )

src/llamafactory/model/utils/unsloth.py    deleted 100644 → 0    (view file @ 37b0ad9f)

from typing import TYPE_CHECKING, Any, Dict, Optional

from ...extras.logging import get_logger
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def _get_unsloth_kwargs(
    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
    return {
        "model_name": model_name_or_path,
        "max_seq_length": model_args.model_max_length or 4096,
        "dtype": model_args.compute_dtype,
        "load_in_4bit": model_args.quantization_bit == 4,
        "token": model_args.hf_hub_token,
        "device_map": {"": get_current_device()},
        "rope_scaling": getattr(config, "rope_scaling", None),
        "fix_tokenizer": False,
        "trust_remote_code": True,
        "use_gradient_checkpointing": "unsloth",
    }


def load_unsloth_pretrained_model(
    config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
    r"""
    Optionally loads pretrained model with unsloth. Used in training.
    """
    from unsloth import FastLanguageModel

    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
    try:
        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError:
        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
        model = None
        model_args.use_unsloth = False

    return model


def get_unsloth_peft_model(
    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
    r"""
    Gets the peft model for the pretrained model with unsloth. Used in training.
    """
    from unsloth import FastLanguageModel

    unsloth_peft_kwargs = {
        "model": model,
        "max_seq_length": model_args.model_max_length,
        "use_gradient_checkpointing": "unsloth",
    }
    return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)


def load_unsloth_peft_model(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
    r"""
    Loads peft model with unsloth. Used in both training and inference.
    """
    from unsloth import FastLanguageModel

    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
    try:
        if not is_trainable:
            unsloth_kwargs["use_gradient_checkpointing"] = False

        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
    except NotImplementedError:
        raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))

    if not is_trainable:
        FastLanguageModel.for_inference(model)

    return model

src/llamafactory/model/utils/valuehead.py    deleted 100644 → 0    (view file @ 37b0ad9f)

from typing import TYPE_CHECKING, Dict

import torch
from transformers.utils import cached_file

from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.

    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
    """
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:
        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        return torch.load(vhead_file, map_location="cpu")
    except Exception as err:
        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))

    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
    return None


def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    if getattr(model.config, "model_type", None) == "llava":
        setattr(model, "lm_head", model.language_model.get_output_embeddings())
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "chatglm":
        setattr(model, "lm_head", model.transformer.output_layer)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

    if getattr(model.config, "model_type", None) == "internlm2":
        setattr(model, "lm_head", model.output)
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

src/llamafactory/model/utils/visual.py    deleted 100644 → 0    (view file @ 37b0ad9f)

from typing import TYPE_CHECKING, Tuple

import torch
import transformers.models
from transformers.activations import ACT2FN

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
    def __init__(self, config: "LlavaConfig") -> None:
        super().__init__()

        self.config = config
        if config is None:
            return

        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]

    def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
        hidden_states = self.linear_1(image_features)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_3(hidden_states)
        hidden_states = self.linear_4(hidden_states)
        if hidden_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.linear_1.weight.dtype

            logger.warning_once("The hidden states seems to be silently casted in float32.")
            hidden_states = hidden_states.to(target_dtype)

        return hidden_states


class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
    def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
        super().__init__(config=None)

        self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
        self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
        self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
        self.act = ACT2FN[projector_hidden_act]


def autocast_projector_dtype(
    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
    def _mm_projector_forward_post_hook(
        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
    ) -> "torch.Tensor":
        return output.to(model_args.compute_dtype)

    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


def configure_visual_model(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

    if getattr(config, "is_yi_vl_derived_model", None):
        logger.info("Detected Yi-VL model, applying projector patch.")
        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL

src/llamafactory/train/callbacks.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -35,7 +35,7 @@ from typing_extensions import override
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.misc import get_peak_memory, use_ray
+from ..extras.misc import get_peak_memory, is_env_enabled, use_ray


 if is_safetensors_available():
...
@@ -193,7 +193,7 @@ class LogCallback(TrainerCallback):
         self.aborted = False
         self.do_train = False
         # Web UI
-        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
+        self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
         if self.webui_mode and not use_ray():
             signal.signal(signal.SIGABRT, self._set_abort)
             self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
...
@@ -299,7 +299,7 @@ class LogCallback(TrainerCallback):
             logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
             logs["total_tokens"] = state.num_input_tokens_seen

-        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
+        if is_env_enabled("RECORD_VRAM"):
             vram_allocated, vram_reserved = get_peak_memory()
             logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
             logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
...
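
Note (not part of the commit): the two hunks above swap inline environment-variable parsing for the is_env_enabled helper imported from ..extras.misc. The helper's definition is not shown on this page; a sketch consistent with the expressions it replaces would be:

# Hypothetical sketch of the helper, mirroring the replaced checks.
import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Treats "true" and "1" (case-insensitive) as enabled; the real helper
    # may accept additional truthy spellings.
    return os.getenv(env_var, default).lower() in ["true", "1"]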

src/llamafactory/train/dpo/__init__.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

src/llamafactory/train/dpo/trainer.py    (view file @ 317a82e2)

...
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
...
@@ -204,7 +204,11 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         chosen_length, _ = valid_length.split(batch_size, dim=0)
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
+
+        if self.loss_type in ["ipo", "orpo", "simpo"]:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
+        else:
+            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

     @override
     def compute_reference_log_probs(
...
@@ -278,19 +282,12 @@ class CustomDPOTrainer(DPOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """
...
@@ -314,4 +311,4 @@ class CustomDPOTrainer(DPOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric

-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)
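
Note (not part of the commit): the log override now forwards extra positional and keyword arguments to Trainer.log, presumably so the subclass keeps working with newer transformers releases that pass additional arguments to log. The pattern in isolation, as a hypothetical standalone sketch:

# Hypothetical illustration; MyTrainer and the dummy_ filter are examples only.
from typing import Dict
from transformers import Trainer

class MyTrainer(Trainer):
    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
        # drop internal bookkeeping entries, then delegate with all extra args
        logs = {k: v for k, v in logs.items() if not k.startswith("dummy_")}
        return Trainer.log(self, logs, *args, **kwargs)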

src/llamafactory/train/kto/__init__.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

src/llamafactory/train/kto/trainer.py    (view file @ 317a82e2)

...
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
...
@@ -156,6 +156,15 @@ class CustomKTOTrainer(KTOTrainer):
         if "image_grid_thw" in batch:
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]

+        if "aspect_ratio_ids" in batch:
+            model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"]
+
+        if "aspect_ratio_mask" in batch:
+            model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"]
+
+        if f"{prefix}cross_attention_mask" in batch:
+            model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]
+
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
         logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
         return logits, logps, logps / valid_length
...
@@ -256,19 +265,12 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """
...
@@ -304,4 +306,4 @@ class CustomKTOTrainer(KTOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric

-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)

src/llamafactory/train/ppo/__init__.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

src/llamafactory/train/ppo/ppo_utils.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

src/llamafactory/train/ppo/utils.py    deleted 100644 → 0    (view file @ 37b0ad9f)

import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.packages import is_requests_available


if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead


if is_requests_available():
    import requests


def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    rewards = json.loads(response.text)["scores"]
    return torch.Tensor(rewards)


def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [model.v_head.summary.weight, model.v_head.summary.bias]
        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    with context_maybe_zero3:
        if target == "reward":  # save default head temporarily
            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())

        model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()


def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
    layer_norm_params = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            layer_norm_params[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)

    return layer_norm_params


def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]
View file @
317a82e2
# Copyright 202
4
the LlamaFactory team.
# Copyright 202
5
the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...

src/llamafactory/train/pt/trainer.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 from transformers import Trainer
...
@@ -25,7 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, ProcessorMixin
+    from transformers import ProcessorMixin

     from ...hparams import FinetuningArguments
...
@@ -72,21 +72,3 @@ class CustomTrainer(Trainer):
             return torch.utils.data.SequentialSampler(self.train_dataset)

         return super()._get_train_sampler()
-
-    @override
-    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
-
-        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
-        """
-        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss

src/llamafactory/train/rm/__init__.py    (view file @ 317a82e2)

-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...