OpenDAS / LLaMA-Factory · Commits

Commit 581d366d, authored Apr 15, 2025 by chenych
Support GLM-4/GLM-4-0414/GLM-Z1
Parent: 428c5813

Showing 20 changed files with 253 additions and 70 deletions (+253, -70).
Changed files:

  src/llamafactory/data/collator.py                     +7   -17
  src/llamafactory/data/data_utils.py                   +49  -0
  src/llamafactory/data/loader.py                       +7   -2
  src/llamafactory/data/mm_plugin.py                    +61  -9
  src/llamafactory/data/parser.py                       +2   -0
  src/llamafactory/data/processor/supervised.py         +5   -5
  src/llamafactory/data/template.py                     +15  -1
  src/llamafactory/extras/constants.py                  +51  -1
  src/llamafactory/extras/misc.py                       +7   -5
  src/llamafactory/hparams/data_args.py                 +6   -0
  src/llamafactory/hparams/model_args.py                +3   -3
  src/llamafactory/hparams/parser.py                    +6   -5
  src/llamafactory/hparams/training_args.py             +4   -0
  src/llamafactory/model/loader.py                      +2   -1
  src/llamafactory/model/model_utils/liger_kernel.py    +1   -1
  src/llamafactory/model/model_utils/moe.py             +6   -0
  src/llamafactory/model/model_utils/quantization.py    +4   -18
  src/llamafactory/model/patcher.py                     +7   -1
  src/llamafactory/train/dpo/workflow.py                +7   -1
  src/llamafactory/train/kto/trainer.py                 +3   -0
src/llamafactory/data/collator.py  (+7, -17)

@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
-from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available

@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
     _, seq_len = attention_mask_with_indices.size()
-    # Move to compute device if the source is CPU.
-    source_device = attention_mask_with_indices.device
-    compute_device = get_current_device() if source_device.type == "cpu" else source_device
-    if compute_device != source_device:
-        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
     min_dtype = torch.finfo(dtype).min
-    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+    zero_tensor = torch.tensor(0, dtype=dtype)
     # Create a non-padding mask.
-    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
     # Create indices for comparison.
     indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
     indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
     # Create a lower triangular mask.
-    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
-    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
+    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
     # Invert the attention mask.
     attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
-    # Move back to original device if needed.
-    if compute_device != source_device:
-        attention_mask_4d = attention_mask_4d.to(source_device)
     return attention_mask_4d

@@ -196,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

         if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+            rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
             feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
             if feature_attention_mask is not None:
                 audio_feature_lengths = torch.sum(

@@ -309,8 +298,9 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
-        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
+        if "cross_attention_mask" in kl_batch:  # for mllama inputs
             batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]

         if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
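The simplified helper now builds the packed-sequence mask entirely on the tensor's own device. Below is a standalone sketch (not code from this commit) that mirrors the new logic and shows what it produces for two packed sequences followed by padding:

    import torch

    def build_4d_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Mirrors the new prepare_4d_attention_mask: same-sequence, non-padding, causal
        # positions get 0.0, everything else gets the dtype's minimum value.
        _, seq_len = mask_with_indices.size()
        min_dtype = torch.finfo(dtype).min
        zero_tensor = torch.tensor(0, dtype=dtype)
        non_padding_mask = (mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
        indices = mask_with_indices.unsqueeze(1).unsqueeze(2)    # [bsz, 1, 1, seq_len]
        indices_t = mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
        tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
        mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
        return torch.where(mask_4d, zero_tensor, min_dtype)

    packed = torch.tensor([[1, 1, 2, 2, 0]])  # two sequences (ids 1 and 2) plus one pad token
    print(build_4d_mask(packed, torch.float32).shape)  # torch.Size([1, 1, 5, 5])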
src/llamafactory/data/data_utils.py  (+49, -0)

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Optional, TypedDict, Union

+import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

 from ..extras import logging

@@ -138,3 +140,50 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu
         dataset_module["train_dataset"] = dataset

     return dataset_module


+def setup_fs(path, anon=False):
+    """Set up a filesystem object based on the path protocol."""
+    storage_options = {"anon": anon} if anon else {}
+    if path.startswith("s3://"):
+        fs = fsspec.filesystem("s3", **storage_options)
+    elif path.startswith(("gs://", "gcs://")):
+        fs = fsspec.filesystem("gcs", **storage_options)
+    else:
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+
+    return fs
+
+
+def read_cloud_json(cloud_path):
+    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+
+    Args:
+        cloud_path : str
+            Cloud path in the format:
+            - 's3://bucket-name/file.json' for AWS S3
+            - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
+        lines : bool, default=True
+            If True, read the file as JSON Lines format (one JSON object per line)
+    """
+    try:
+        # Try with anonymous access first
+        fs = setup_fs(cloud_path, anon=True)
+        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+    except Exception:
+        # Try again with credentials
+        fs = setup_fs(cloud_path)
+        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+
+
+def _read_json_with_fs(fs, path, lines=True):
+    """Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if lines:
+            # Read JSONL (JSON Lines) format - one JSON object per line
+            data = [json.loads(line) for line in f if line.strip()]
+        else:
+            # Read regular JSON format
+            data = json.load(f)
+
+    return data
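Hypothetical usage of the new cloud reader; the bucket and object names are placeholders, and the helper first tries anonymous access before falling back to whatever credentials fsspec (s3fs/gcsfs) can pick up:

    from llamafactory.data.data_utils import read_cloud_json

    rows = read_cloud_json("s3://my-bucket/data/train.jsonl")  # .jsonl is parsed line by line
    print(len(rows), sorted(rows[0].keys()))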
src/llamafactory/data/loader.py  (+7, -2)

@@ -16,13 +16,13 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
-from datasets import load_dataset, load_from_disk
+from datasets import Dataset, load_dataset, load_from_disk

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
 from .converter import align_dataset
-from .data_utils import get_dataset_module, merge_dataset, split_dataset
+from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
 from .parser import get_dataset_list
 from .processor import (
     FeedbackDatasetProcessor,

@@ -67,6 +67,9 @@ def _load_single_dataset(
         data_name = dataset_attr.subset
         data_dir = dataset_attr.folder

+    elif dataset_attr.load_from == "cloud_file":
+        data_path = dataset_attr.dataset_name
+
     elif dataset_attr.load_from == "file":
         data_files = []
         local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)

@@ -122,6 +125,8 @@ def _load_single_dataset(
             token=model_args.om_hub_token,
             streaming=data_args.streaming,
         )
+    elif dataset_attr.load_from == "cloud_file":
+        dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
     else:
         dataset = load_dataset(
             path=data_path,
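A minimal sketch of what the new cloud_file branch does once the JSON records are in memory; the records below stand in for the output of read_cloud_json, and Dataset.from_list is the datasets API used in the hunk above:

    from datasets import Dataset

    records = [{"instruction": "Say hi.", "input": "", "output": "Hi!"}]  # placeholder rows
    dataset = Dataset.from_list(records, split="train")
    print(dataset)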
src/llamafactory/data/mm_plugin.py  (+61, -9)

@@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
         return mm_inputs


+@dataclass
+class KimiVLPlugin(BasePlugin):
+    @override
+    def process_messages(self, messages, images, videos, audios, processor):
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_hws = mm_inputs.get("image_grid_hws", [])
+
+        num_image_tokens = 0
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        merge_length = math.prod(image_processor.merge_kernel_size)
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_grid_hws):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+
 @dataclass
 class Llama4Plugin(BasePlugin):
     @override

@@ -493,8 +528,8 @@ class Llama4Plugin(BasePlugin):
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            placeholder_count = content.count(IMAGE_PLACEHOLDER)
             if self.expand_mm_tokens:
+                placeholder_count = content.count(IMAGE_PLACEHOLDER)
                 prompt_splits = content.split(IMAGE_PLACEHOLDER)
                 new_content = []
                 for local_image_index, split_part in enumerate(prompt_splits):

@@ -507,6 +542,8 @@ class Llama4Plugin(BasePlugin):
                     new_content.append(tokens_for_this_image)

                 content = "".join(new_content)
+            else:
+                content = content.replace(IMAGE_PLACEHOLDER, self.image_token)

             message["content"] = content

@@ -1376,6 +1413,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         else:
             mm_inputs = {}

+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)

@@ -1396,16 +1434,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if audio_lengths is None:
                 raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")

-            if not mm_inputs.get("video_grid_thw", None):
+            if mm_inputs.get("video_grid_thw", None) is None:
                 raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")

         positions_list = []
-        for i, message in enumerate(messages):  # get multimodal index when use_audio
+        for message in messages:  # get multimodal index when use_audio
             positions = []
             for special_token in [self.audio_token, self.image_token, self.video_token]:
                 start = 0
                 while True:
-                    pos = message[i].find(special_token, start)
+                    pos = message["content"].find(special_token, start)
                     if pos == -1:
                         break
                     positions.append((pos, special_token))

@@ -1417,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             content = message["content"]
             # separate with audio-video
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_grid_thw):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
                 content = content.replace(
                     IMAGE_PLACEHOLDER,

@@ -1427,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if not use_audio_in_video:
                 while AUDIO_PLACEHOLDER in content:
+                    if num_audio_tokens >= len(audio_lengths):
+                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                     audio_token_replace_length = audio_lengths[num_audio_tokens]
                     content = content.replace(
                         AUDIO_PLACEHOLDER,

@@ -1437,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 # TODO handle video_input and use_audio_in_video
                 while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(video_grid_thw):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                     video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
                     content = content.replace(
                         VIDEO_PLACEHOLDER,
                         f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>",
                         1

@@ -1445,14 +1492,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             else:  # if use the audio of video # deal video token and audio token togather
                 while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(video_grid_thw):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                     audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                     video_t_index = (
                         torch.arange(video_grid_thw[num_video_tokens][0])
                         .view(-1, 1, 1)
                         .expand(
                             -1,
-                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
-                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
                         )
                         .flatten()
                         * mm_inputs["video_second_per_grid"][num_video_tokens]

@@ -1460,18 +1510,19 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                     video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
-                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    placeholder_string = ""
+                    placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                     for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                         video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                         audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
-                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                         if video_chunk_index is not None:
                             placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])

                         if audio_chunk_index is not None:
                             placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])

-                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                     content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                     num_audio_tokens += 1

@@ -1552,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
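A back-of-the-envelope check of the Kimi-VL expansion above: one image whose grid is 28x28 patches with a (2, 2) merge kernel yields 28*28 // 4 = 196 pad tokens between the media markers. The grid and kernel sizes here are illustrative values, not ones read from the real processor:

    import math

    image_grid_hw = (28, 28)      # illustrative patch grid for one image
    merge_kernel_size = (2, 2)    # illustrative merge kernel
    merge_length = math.prod(merge_kernel_size)
    image_seqlen = (image_grid_hw[0] * image_grid_hw[1]) // merge_length
    placeholder = f"<|media_start|>image<|media_content|>{'<|media_pad|>' * image_seqlen}<|media_end|>"
    print(image_seqlen)  # 196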
src/llamafactory/data/parser.py  (+2, -0)

@@ -141,6 +141,8 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
             dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
         elif "script_url" in dataset_info[name]:
             dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+        elif "cloud_file_name" in dataset_info[name]:
+            dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
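A hypothetical dataset_info.json entry that routes through the new cloud_file branch; the dataset name, bucket path, and formatting value are placeholders, not part of this commit:

    import json

    dataset_info = {
        "my_cloud_dataset": {
            "cloud_file_name": "s3://my-bucket/data/train.jsonl",  # picked up as load_from="cloud_file"
            "formatting": "alpaca",
        }
    }
    print(json.dumps(dataset_info, indent=2))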
src/llamafactory/data/processor/supervised.py  (+5, -5)

@@ -164,28 +164,28 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
         model_inputs = defaultdict(list)
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
-            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
-            packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
+            packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
+            packed_images, packed_videos, packed_audios = [], [], []
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]
+                packed_position_ids += list(range(len(batch_input_ids[index])))  # NOTE: pad_to_multiple_of ignore this
                 packed_labels += batch_labels[index]
                 packed_images += batch_images[index]
                 packed_videos += batch_videos[index]
                 packed_audios += batch_audios[index]
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
-                    packed_position_ids += list(range(len(batch_input_ids[index])))
                 else:
                     packed_attention_masks += [1] * len(batch_input_ids[index])

             if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                 pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                 packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
+                packed_position_ids += [0] * pad_length
                 packed_labels += [IGNORE_INDEX] * pad_length
                 if self.data_args.neat_packing:
                     packed_attention_masks += [0] * pad_length
-                    packed_position_ids += [0] * pad_length
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn

@@ -194,10 +194,10 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
             model_inputs["input_ids"].append(packed_input_ids)
             model_inputs["attention_mask"].append(packed_attention_masks)
+            model_inputs["position_ids"].append(packed_position_ids)
             model_inputs["labels"].append(packed_labels)
             model_inputs["images"].append(packed_images or None)
             model_inputs["videos"].append(packed_videos or None)
             model_inputs["audios"].append(packed_audios or None)
-            model_inputs["position_ids"].append(packed_position_ids or None)

         return model_inputs
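Sketch of the position ids that are now recorded for every packed example rather than only under neat_packing; the token ids below are made up, the point is that each source sequence restarts its positions at zero:

    batch_input_ids = [[101, 7, 8, 102], [101, 9, 102]]  # two already-tokenized examples
    packed_input_ids, packed_position_ids = [], []
    for ids in batch_input_ids:
        packed_input_ids += ids
        packed_position_ids += list(range(len(ids)))
    print(packed_position_ids)  # [0, 1, 2, 3, 0, 1, 2]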
src/llamafactory/data/template.py  (+15, -1)

@@ -923,6 +923,20 @@ register_template(
 )


+register_template(
+    name="kimi_vl",
+    format_user=StringFormatter(
+        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
+    default_system="You are a helpful assistant",
+    stop_words=["<|im_end|>"],
+    thought_words=("◁think▷", "◁/think▷"),
+    mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
+)
+
+
 register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),

@@ -1370,7 +1384,7 @@ register_template(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    default_system="You are a helpful assistant.",
+    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
 )
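How a single-turn exchange renders under the kimi_vl template registered above; the string is assembled by hand here to illustrate the slots rather than produced by the library's encoder, and the user/assistant text is made up:

    system = "You are a helpful assistant"
    user, assistant = "Describe the image. <image>", "A cat on a sofa."
    rendered = (
        f"<|im_system|>system<|im_middle|>{system}<|im_end|>"
        f"<|im_user|>user<|im_middle|>{user}<|im_end|><|im_assistant|>assistant<|im_middle|>"
        f"{assistant}<|im_end|>"
    )
    print(rendered)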
src/llamafactory/extras/constants.py  (+51, -1)

@@ -14,7 +14,7 @@
 import os
 from collections import OrderedDict, defaultdict
-from enum import Enum
+from enum import Enum, unique
 from typing import Optional

 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME

@@ -115,6 +115,19 @@ class DownloadSource(str, Enum):
     OPENMIND = "om"


+@unique
+class QuantizationMethod(str, Enum):
+    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
+
+    BNB = "bnb"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+    QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
+
+
 class RopeScaling(str, Enum):
     LINEAR = "linear"
     DYNAMIC = "dynamic"

@@ -133,6 +146,7 @@ def register_model_group(
             any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
         ):
             DEFAULT_TEMPLATE[name] = template

         if multimodal:
             MULTIMODAL_SUPPORTED_MODELS.add(name)

@@ -711,6 +725,26 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
         },
+        "GLM-4-9B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4-9B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-0414",
+        },
+        "GLM-4-32B-0414": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Base-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Base-0414",
+        },
+        "GLM-4-32B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
+        },
+        "GLM-Z1-9B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
+        },
+        "GLM-Z1-32B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-Z1-32B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
+        },
     },
     template="glm4",
 )

@@ -941,6 +975,22 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Kimi-VL-A3B-Instruct": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
+        },
+        "Kimi-VL-A3B-Thinking": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
+        },
+    },
+    template="kimi_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LingoWhale-8B": {
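The relocated QuantizationMethod is a str-valued Enum, which is what lets the quantization code later in this commit compare it directly against plain strings. A minimal reproduction with a subset of the members (a stand-in, not an import of the real class):

    from enum import Enum, unique

    @unique
    class QuantizationMethod(str, Enum):
        BNB = "bnb"
        HQQ = "hqq"
        EETQ = "eetq"

    assert QuantizationMethod.BNB == "bnb"              # str mixin: equal to its raw value
    assert QuantizationMethod("hqq") is QuantizationMethod.HQQ
    print(QuantizationMethod.BNB.value)  # bnb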
src/llamafactory/extras/misc.py  (+7, -5)

@@ -89,10 +89,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.4.1")
-    check_version("accelerate>=0.34.0,<=1.5.2")
-    check_version("peft>=0.14.0,<=0.15.0")
+    check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.5.0")
+    check_version("accelerate>=0.34.0,<=1.6.0")
+    check_version("peft>=0.14.0,<=0.15.1")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")

@@ -177,6 +177,8 @@ def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
     if is_torch_npu_available():
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_xpu_available():
+        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:

@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
 def is_gpu_or_npu_available() -> bool:
     r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available()
+    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()


 def is_env_enabled(env_var: str, default: str = "0") -> bool:
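A quick way to exercise the new XPU branch alongside the existing NPU/CUDA ones; the call itself is real, but what it returns on a machine without any accelerator depends on the else branch that is collapsed in the hunk above:

    from llamafactory.extras.misc import get_peak_memory

    allocated, reserved = get_peak_memory()  # peak bytes on the active NPU/XPU/CUDA device
    print(allocated, reserved)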
src/llamafactory/hparams/data_args.py  (+6, -0)

@@ -160,5 +160,11 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")

+        if self.neat_packing:
+            self.packing = True
+
+        if self.packing:
+            self.cutoff_len -= 1  # avoid pad_to_multiple_of, needs improve
+
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
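The effect of the new __post_init__ logic, sketched with plain variables instead of the real DataArguments instance: enabling neat_packing now implies packing, and packing shaves one token off the effective cutoff:

    neat_packing, packing, cutoff_len = True, False, 2048  # illustrative starting values
    if neat_packing:
        packing = True
    if packing:
        cutoff_len -= 1  # avoid pad_to_multiple_of, per the comment in the diff
    print(packing, cutoff_len)  # True 2047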
src/llamafactory/hparams/model_args.py  (+3, -3)

@@ -23,7 +23,7 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self

-from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling


 @dataclass

@@ -184,8 +184,8 @@ class BaseModelArguments:
 class QuantizationArguments:
     r"""Arguments pertaining to the quantization method."""

-    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
-        default="bitsandbytes",
+    quantization_method: QuantizationMethod = field(
+        default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
     quantization_bit: Optional[int] = field(
src/llamafactory/hparams/parser.py  (+6, -5)

@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.2")
+        check_version("vllm>=0.4.3,<=0.8.4")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")

@@ -285,10 +285,6 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

-    if data_args.neat_packing and not data_args.packing:
-        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
-        data_args.packing = True
-
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

@@ -394,8 +390,10 @@
 def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

+    # Setup logging
     _set_transformers_logging()

+    # Check arguments
     if model_args.infer_backend == "vllm":
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")

@@ -412,6 +410,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)

+    # Post-process model arguments
     if model_args.export_dir is not None and model_args.export_device == "cpu":
         model_args.device_map = {"": torch.device("cpu")}

     if data_args.cutoff_len != DataArguments().cutoff_len:  # override cutoff_len if it is not default

@@ -425,8 +424,10 @@
 def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

+    # Setup logging
     _set_transformers_logging()

+    # Check arguments
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
src/llamafactory/hparams/training_args.py  (+4, -0)

@@ -46,6 +46,10 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
+    ray_init_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
+    )

     def __post_init__(self):
         self.use_ray = use_ray()
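A hedged illustration of the new field: per its help text, whatever dict is supplied as ray_init_kwargs is passed to ray.init for Ray training. The keys below are standard ray.init arguments; the values are placeholders:

    # Standard ray.init keyword arguments; forwarded by the Ray trainer per the field's help text.
    ray_init_kwargs = {"num_cpus": 16, "ignore_reinit_error": True}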
src/llamafactory/model/loader.py  (+2, -1)

@@ -97,12 +97,13 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Processor was not found: {e}.")
+        logger.debug(f"Failed to load processor: {e}.")
         processor = None

     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
     if processor is not None and "Processor" not in processor.__class__.__name__:
+        logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None

     return {"tokenizer": tokenizer, "processor": processor}
src/llamafactory/model/model_utils/liger_kernel.py  (+1, -1)

@@ -45,7 +45,7 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    if model_type == "paligemma":
+    elif model_type == "paligemma":
         from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
src/llamafactory/model/model_utils/moe.py  (+6, -0)

@@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

+    if model_type in ["kimi_vl", "deepseek_v3"]:
+        check_version("transformers>=4.51.1")
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
src/llamafactory/model/model_utils/quantization.py  (+4, -18)

@@ -18,7 +18,6 @@
 import os
 import random
-from enum import Enum, unique
 from typing import TYPE_CHECKING, Any

 import torch

@@ -28,7 +27,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

 from ...extras import logging
-from ...extras.constants import FILEEXT2TYPE
+from ...extras.constants import FILEEXT2TYPE, QuantizationMethod
 from ...extras.misc import check_version, get_current_device

@@ -41,19 +40,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-@unique
-class QuantizationMethod(str, Enum):
-    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
-
-    BITS_AND_BYTES = "bitsandbytes"
-    GPTQ = "gptq"
-    AWQ = "awq"
-    AQLM = "aqlm"
-    QUANTO = "quanto"
-    EETQ = "eetq"
-    HQQ = "hqq"
-
-
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
     r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):

@@ -145,7 +131,7 @@ def configure_quantization(
         logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
     elif model_args.quantization_bit is not None:  # on-the-fly
-        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+        if model_args.quantization_method == QuantizationMethod.BNB:
             if model_args.quantization_bit == 8:
                 check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

@@ -173,7 +159,7 @@ def configure_quantization(
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
-        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+        elif model_args.quantization_method == QuantizationMethod.HQQ:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")

@@ -185,7 +171,7 @@ def configure_quantization(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
-        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+        elif model_args.quantization_method == QuantizationMethod.EETQ:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
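For reference, the 8-bit bitsandbytes branch above builds the following config object and hands it to from_pretrained through init_kwargs["quantization_config"]; shown standalone here as a minimal sketch:

    from transformers import BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    print(quant_config.load_in_8bit)  # True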
src/llamafactory/model/patcher.py  (+7, -1)

@@ -79,6 +79,7 @@ def patch_processor(
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
     setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
+    setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)


 def patch_config(

@@ -95,7 +96,8 @@ def patch_config(
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

     if is_torch_npu_available():
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))

     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)

@@ -115,6 +117,10 @@ def patch_config(
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

+    # replace the top-k gating method
+    if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
+        setattr(config.text_config, "topk_method", "greedy")
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
src/llamafactory/train/dpo/workflow.py  (+7, -1)

@@ -91,7 +91,13 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
+            keys = ["loss", "rewards/accuracies"]
+            if isinstance(dataset_module.get("eval_dataset"), dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)

     # Evaluation
     if training_args.do_eval:
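Sketch of the plot keys now produced when the eval dataset is a dict of named splits; the split names are made up, and dataset_module here is a plain dict standing in for the real module:

    dataset_module = {"eval_dataset": {"alpaca_en": None, "glaive_toolcall": None}}  # placeholder splits
    keys = ["loss", "rewards/accuracies"]
    if isinstance(dataset_module.get("eval_dataset"), dict):
        keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
    else:
        keys += ["eval_loss"]
    print(keys)  # ['loss', 'rewards/accuracies', 'eval_alpaca_en_loss', 'eval_glaive_toolcall_loss']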
src/llamafactory/train/kto/trainer.py  (+3, -0)

@@ -147,6 +147,9 @@ class CustomKTOTrainer(KTOTrainer):
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]

         if "image_sizes" in batch:
             model_inputs["image_sizes"] = batch["image_sizes"]

+        if "image_grid_thw" in batch:
+            model_inputs["image_grid_thw"] = batch["image_grid_thw"]
+