OpenDAS / LLaMA-Factory, commit ca625f43

Authored Mar 30, 2026 by shihm. Commit message: "uodata". Parent: 7164651d.

Changes: 327. Showing 20 changed files with 1526 additions and 0 deletions (+1526, -0):
- src/llamafactory/v1/core/trainer_utils/data_collator.py (+119, -0)
- src/llamafactory/v1/core/trainer_utils/data_loader.py (+277, -0)
- src/llamafactory/v1/core/trainer_utils/lr_scheduler.py (+0, -0)
- src/llamafactory/v1/core/utils/__init__.py (+0, -0)
- src/llamafactory/v1/core/utils/batching.py (+251, -0)
- src/llamafactory/v1/core/utils/callback.py (+0, -0)
- src/llamafactory/v1/core/utils/inference_engine.py (+121, -0)
- src/llamafactory/v1/core/utils/rendering.py (+176, -0)
- src/llamafactory/v1/launcher.py (+66, -0)
- src/llamafactory/v1/plugins/__init__.py (+0, -0)
- src/llamafactory/v1/plugins/data_plugins/__init__.py (+0, -0)
- src/llamafactory/v1/plugins/data_plugins/converter.py (+182, -0)
- src/llamafactory/v1/plugins/data_plugins/loader.py (+114, -0)
- src/llamafactory/v1/plugins/data_plugins/template.py (+133, -0)
- src/llamafactory/v1/plugins/model_plugins/__init__.py (+0, -0)
- src/llamafactory/v1/plugins/model_plugins/add_token.py (+0, -0)
- src/llamafactory/v1/plugins/model_plugins/added_token.py (+0, -0)
- src/llamafactory/v1/plugins/model_plugins/initialization.py (+0, -0)
- src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py (+0, -0)
- src/llamafactory/v1/plugins/model_plugins/kernels/base.py (+87, -0)
src/llamafactory/v1/core/trainer_utils/data_collator.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate

from ....extras.constants import IGNORE_INDEX
from ...plugins.data_plugins.template import Template
from ...utils.types import Processor, Tensor


def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":  # FIXME: move to utils
    """Convert sequence lengths to cumulative sequence lengths."""
    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)


class DataCollator:
    """Default data collator."""

    processor: "Processor"  # processor name -> map to encode_messages function

    def __post_init__(self):
        # callback for the text tokenizer; fall back to the processor itself if it has none
        self.tokenizer = self.processor.tokenizer if hasattr(self.processor, "tokenizer") else self.processor

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
        """Collate features into a batch."""
        batch = defaultdict(list)
        # gather features key-wise
        for feature in features:
            for key in feature.keys():
                batch[key].append(feature[key])

        for key in batch.keys():
            # pad sequence-like features; everything else goes through default_collate
            if key in ["input_ids", "attention_mask", "position_ids"]:
                padding_value = self.tokenizer.pad_token_id if key == "input_ids" else 0
                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=padding_value)
            elif key in ["labels"]:
                batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
            else:
                batch[key] = default_collate(batch[key])

        return batch


# sft: messages
# dpo: chosen_messages, rejected_messages
@dataclass
class DefaultCollator(DataCollator):
    """Example for now."""

    processor: "Processor"  # processor name -> map to encode_messages function
    template: "Template"

    def __call__(self, messages: list[list[dict[str, Any]]]) -> dict[str, Tensor]:
        features = []
        # check if the data is already tokenized (contains input_ids)
        if messages and isinstance(messages[0], dict) and "input_ids" in messages[0]:
            for feature in messages:
                if not isinstance(feature, dict):
                    raise ValueError(f"Expected dict but got {type(feature)}")

                tensor_feature = {
                    k: torch.tensor(v, dtype=torch.long) if not isinstance(v, torch.Tensor) else v
                    for k, v in feature.items()
                }
                features.append(tensor_feature)
        else:  # raw messages need to be encoded
            for message in messages:
                encoded_message = self.template.encode_messages(self.tokenizer, message)
                encoded_message = {k: torch.tensor(v, dtype=torch.long) for k, v in encoded_message.items()}
                features.append(encoded_message)

        return super().__call__(features)


@dataclass
class PairwiseCollator(DataCollator):
    pass


@dataclass
class DataCollatorWithPacking(DefaultCollator):
    """Data collator with packing."""

    processor: "Processor"
    template: "Template"

    def __call__(self, features: Sequence[dict[str, "torch.Tensor"]]) -> dict[str, "torch.Tensor"]:
        seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
        batch = {"cu_seqlens": len2culen(seqlens)}
        for input_name in features[0].keys():
            if input_name in ("input_ids", "attention_mask", "labels"):
                batch[input_name] = torch.cat([feature[input_name] for feature in features])
            else:
                batch[input_name] = default_collate([feature[input_name] for feature in features])

        return batch
```
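For reference, `len2culen` produces the `cu_seqlens` layout that varlen attention kernels consume: entry `i` is the offset where packed sequence `i` starts, and the final entry is the total token count. A minimal standalone check (plain PyTorch, runnable on its own):

```python
import torch
import torch.nn.functional as F

def len2culen(seqlens: torch.Tensor) -> torch.Tensor:
    # prepend a zero to the running total of sequence lengths
    return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)

seqlens = torch.tensor([3, 5, 2])  # three sequences packed back to back
print(len2culen(seqlens))          # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```

`DataCollatorWithPacking` pairs this with `torch.cat` over `input_ids`, so the model sees one long row plus boundary offsets instead of a padded matrix.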
src/llamafactory/v1/core/trainer_utils/data_loader.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import sys
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from typing import Any, Optional

from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ...utils.batching_queue import BaseBatchingQueue
from ...utils.logging import get_logger
from ...utils.types import Processor, TorchDataset
from .data_collator import DataCollator


logger = get_logger(__name__)


# base dataloader
class DistributedDataloader(StatefulDataLoader):
    """Base distributed dataloader."""

    dataset: "TorchDataset"
    sampler: "StatefulDistributedSampler"

    def set_epoch(self, epoch: int) -> None:
        if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)


@dataclass
class BaseDataLoader:
    """Default dataloader."""

    processor: Processor

    def __init__(self, dataset: TorchDataset) -> None:
        self.dataset = dataset

    # guidelines: fetch until we collect a fixed batch size,
    # save a state_dict for the buffer, and resume from that state.
    # 1. Init stateful dataloader (tokenize)
    # 2. Add to buffer (2 * max seq len per device)
    # 3. Yield batch indexes (micro batch * grad acc)
    #    a) non pack + non dynamic
    #    b) non pack + dynamic
    #    c) pack + non dynamic
    #    d) pack + dynamic
    def init_dataloader(self) -> None:
        ### init dataloader
        pass

    def __iter__(self) -> Iterator:
        pass

    def __next__(self) -> Any:
        pass


@dataclass
class DataLoader:
    """Default dataloader."""

    processor: "Processor"
    dataloader: "DistributedDataloader"
    batching_queue: "BaseBatchingQueue"
    collate_fn: "DataCollator"
    num_micro_batch: int = 1
    length: int = 0
    drop_last: bool = True

    def __init__(
        self,
        dataloader: Any,
        collate_fn: "DataCollator",
        num_micro_batch: int = 1,
        length: int = 0,
        drop_last: bool = True,
        batching_queue: Optional["BaseBatchingQueue"] = None,
    ) -> None:
        self.batching_queue = batching_queue
        self.num_micro_batch = num_micro_batch
        self.step = 0
        self._collate_fn = collate_fn
        self._dataloader = dataloader
        self._drop_last = drop_last
        self._data_iter: Iterator
        self._resume = False
        self._batch_data_iter: Generator
        if length > 0:
            self._length = length
        elif length == -1:
            self._length = sys.maxsize
        else:
            self._length = len(self._dataloader)

    def __len__(self):
        return self._length

    def __iter__(self) -> Iterator:
        if not self._resume:
            self.step = 0
            self._data_iter = iter(self._dataloader)

        self._batch_data_iter = self.batch_data_generator()
        self._resume = False
        return self

    def __next__(self):
        return next(self._batch_data_iter)

    # FIXME: maybe we can move origin_batch_data_generator to here
    def origin_batch_data_generator(self):
        """Standard pass-through generator used when no batching queue is set."""
        while True:
            if self._length > 0 and self.step >= self._length:
                return

            try:
                batch = []
                data = next(self._data_iter)
                # split data into micro batches
                for i in range(0, len(data), self.num_micro_batch):
                    micro_batch = data[i : i + self.num_micro_batch]
                    if self._collate_fn:
                        micro_batch = self._collate_fn(micro_batch)

                    batch.append(micro_batch)

                yield batch
                self.step += 1
            except StopIteration:
                if self.step < self._length:
                    # restart the iterator to fill the requested length
                    self._data_iter = iter(self._dataloader)
                    try:
                        batch = []
                        data = next(self._data_iter)
                        for i in range(0, len(data), self.num_micro_batch):
                            micro_batch = data[i : i + self.num_micro_batch]
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)

                            batch.append(micro_batch)

                        yield batch
                        self.step += 1
                    except StopIteration:
                        return
                else:
                    return
            except Exception as e:
                logger.error(f"DataLoader origin_batch_data_generator exception: {e}")
                raise

    def batch_data_generator(self):
        if self.batching_queue is None:
            yield from self.origin_batch_data_generator()
            return

        batch = []
        while True:
            if self._length and self.step >= self._length:
                return

            if self.batching_queue.is_full_filled():
                micro_batch = self.batching_queue.get_micro_batch(self.step)
                if self._collate_fn:
                    micro_batch = self._collate_fn(micro_batch)

                batch.append(micro_batch)
                if len(batch) == self.num_micro_batch:
                    yield batch
                    self.step += 1
                    batch = []

            try:
                processing_item = next(self._data_iter)
            except Exception as e:
                if isinstance(e, StopIteration):
                    if self.step < self._length:
                        # call iter again until we reach the requested length
                        self._data_iter = iter(self._dataloader)
                        processing_item = next(self._data_iter)
                    elif not self._drop_last and not self.batching_queue.empty():
                        # drain the queue, then pad the final batch up to num_micro_batch
                        while not self.batching_queue.empty():
                            micro_batch = self.batching_queue.get_micro_batch(self.step)
                            if self._collate_fn:
                                micro_batch = self._collate_fn(micro_batch)

                            batch.append(micro_batch)
                            if len(batch) == self.num_micro_batch:
                                yield batch
                                self.step += 1
                                batch = []

                        while len(batch) < self.num_micro_batch:
                            padding_batch = copy.deepcopy(micro_batch)
                            padding_batch["is_padded"] = True
                            batch.append(padding_batch)

                        yield batch
                        self.step += 1
                        return
                    else:
                        return
                else:
                    logger.error(f"DataLoader iter data exception: {e}")
                    raise

            # put processing_item into the buffer
            if isinstance(processing_item, dict):
                processing_item = [processing_item]

            for item in processing_item:
                self.batching_queue.put_item(item)

    def state_dict(self):
        # save state
        state = self.__dict__.copy()
        # remove internal fields
        for k in list(state.keys()):
            if k.startswith("_"):
                del state[k]

        # save the wrapped dataloader's state
        if hasattr(self._dataloader, "state_dict"):
            state["dataloader_state"] = self._dataloader.state_dict()
        elif hasattr(self._dataloader, "__getstate__"):
            state["dataloader_state"] = self._dataloader.__getstate__()

        batching_strategy = getattr(self, "batching_strategy", None)
        if batching_strategy and hasattr(batching_strategy, "state_dict"):
            state["batching_strategy_state"] = batching_strategy.state_dict()

        if "batching_strategy" in state:
            del state["batching_strategy"]

        return copy.deepcopy(state)

    def load_state_dict(self, state: dict[str, Any]):
        if state["num_micro_batch"] != self.num_micro_batch:
            logger.warning(
                f"num_micro_batch changed: [{state['num_micro_batch']} -> {self.num_micro_batch}], "
                "will clear prefetch buffer"
            )
            del state["num_micro_batch"]

        self.__dict__.update(state)
        self._resume = True
        if hasattr(self._dataloader, "load_state_dict"):
            self._dataloader.load_state_dict(state["dataloader_state"])
        elif hasattr(self._dataloader, "__getstate__"):
            self._dataloader.__setstate__(state["dataloader_state"])

        if "batching_strategy_state" in state:
            batching_strategy = getattr(self, "batching_strategy", None)
            if batching_strategy:
                batching_strategy.load_state_dict(state["batching_strategy_state"])

            del state["batching_strategy_state"]

        self._data_iter = iter(self._dataloader)
        self._batch_data_iter = self.batch_data_generator()

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self._dataloader, "set_epoch"):
            self._dataloader.set_epoch(epoch)
```
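The public/underscore field split above drives checkpointing: `state_dict` keeps only the public counters (`step`, `num_micro_batch`, `batching_queue`) plus the wrapped dataloader's own state, and `load_state_dict` sets `_resume` so the next `__iter__` does not rewind. A resume sketch, assuming `make_inner_loader()` is a hypothetical helper returning a `StatefulDataLoader`-compatible object and `my_collator` an existing `DataCollator` (neither is part of this commit):

```python
loader = DataLoader(make_inner_loader(), collate_fn=my_collator, num_micro_batch=4)

ckpt = None
for step, batch in enumerate(loader):  # batch is a list of num_micro_batch micro batches
    ...                                # run one optimizer step here
    if step == 100:
        ckpt = loader.state_dict()     # public fields + "dataloader_state"; "_" fields stripped
        break

resumed = DataLoader(make_inner_loader(), collate_fn=my_collator, num_micro_batch=4)
resumed.load_state_dict(ckpt)          # sets _resume, so iteration continues where we stopped
```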
src/llamafactory/v1/core/trainer_utils/lr_scheduler.py (new empty file, mode 100644)

src/llamafactory/v1/core/utils/__init__.py (new empty file, mode 100644)
src/llamafactory/v1/core/utils/batching.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batching utils supporting the stateful dataloader.

1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
    a) non pack + non dynamic
    b) non pack + dynamic
    c) pack + non dynamic
    d) pack + dynamic
"""

from collections.abc import Iterator
from typing import Any

import torch
from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ...accelerator.interface import Dim, DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.types import BatchInfo, BatchInput, ModelInput, TorchDataset
from .rendering import Renderer


logger = logging.get_logger(__name__)


def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    micro_batch_size = batch_info["micro_batch_size"]
    num_micro_batch = batch_info["num_micro_batch"]
    cutoff_len = batch_info["cutoff_len"]
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        return None

    samples = buffer.get(batch_size)
    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

    return batch


class BatchGenerator(Iterator):
    def __init__(
        self,
        dataset: TorchDataset,
        renderer: Renderer,
        micro_batch_size: int = 1,
        global_batch_size: int | None = None,
        cutoff_len: int = 2048,
        batching_workers: int = 0,
        batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
        pin_memory: bool = True,
        drop_last: bool = True,
        seed: int = 42,
    ) -> None:
        self.dataset = dataset
        self.renderer = renderer
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.cutoff_len = cutoff_len
        self.batching_workers = batching_workers
        self.batching_strategy = batching_strategy
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.seed = seed
        # TODO: support length and infinity
        dp_size = DistributedInterface().get_world_size(Dim.DP)
        if self.global_batch_size is None:
            self.global_batch_size = dp_size * micro_batch_size
            self.num_micro_batch = 1
        elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
            self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
        else:
            raise ValueError(
                "Global batch size must be divisible by DP size and micro batch size. "
                f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
            )

        if not self.drop_last:
            raise ValueError("Drop last must be True.")

        self._init_data_provider()
        self._is_resuming: bool = False
        self._data_iter = iter(self._data_provider)
        self._buffer = StatefulBuffer()
        self._batch_info: BatchInfo = {
            "micro_batch_size": self.micro_batch_size,
            "num_micro_batch": self.num_micro_batch,
            "cutoff_len": self.cutoff_len,
            "data_iter": self._data_iter,
        }
        logger.info_rank0(
            f"Init unified data loader with global batch size {self.global_batch_size}, "
            f"micro batch size {self.micro_batch_size}, "
            f"num micro batch {self.num_micro_batch}, "
            f"cutoff len {self.cutoff_len}, "
            f"batching workers {self.batching_workers}, "
            f"batching strategy {self.batching_strategy}."
        )

    def _init_data_provider(self) -> None:
        if len(self.dataset) != -1:
            sampler = StatefulDistributedSampler(
                self.dataset,
                num_replicas=DistributedInterface().get_world_size(Dim.DP),
                rank=DistributedInterface().get_rank(Dim.DP),
                shuffle=True,
                seed=self.seed,
                drop_last=self.drop_last,
            )
        else:
            raise NotImplementedError("Iterable dataset is not supported yet.")

        generator_seed = torch.Generator()
        generator_seed.manual_seed(self.seed)
        self._data_provider = StatefulDataLoader(
            self.dataset,
            batch_size=self.micro_batch_size * self.num_micro_batch,
            sampler=sampler,
            num_workers=self.batching_workers,
            collate_fn=self.renderer.process_samples,
            pin_memory=self.pin_memory,
            pin_memory_device=DistributedInterface().current_device.type,
            drop_last=self.drop_last,
            generator=generator_seed,
        )
        if self.batching_strategy == BatchingStrategy.NORMAL:
            self._length = len(self._data_provider)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
            raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")

    def __len__(self) -> int:
        return self._length

    def __iter__(self):
        if not self._is_resuming:
            self._buffer.clear()
            self._buffer_tokens = 0
            self._data_iter = iter(self._data_provider)

        self._is_resuming = False
        return self

    def __next__(self):
        self._fill_buffer()
        batch = self._generate_batch()
        if batch is None:
            raise StopIteration

        return batch

    def _fill_buffer(self) -> None:
        if self.batching_strategy == BatchingStrategy.NORMAL:
            while len(self._buffer) < self.micro_batch_size * self.num_micro_batch:
                try:
                    samples: list[ModelInput] = next(self._data_iter)
                except StopIteration:
                    break

                self._buffer.put(samples)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)

    def _generate_batch(self) -> list[BatchInput] | None:
        if self.batching_strategy == BatchingStrategy.NORMAL:
            return default_collate_fn(self._buffer, self._batch_info)
        else:
            from ...plugins.trainer_plugins.batching import BatchingPlugin

            return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)

    def state_dict(self) -> dict[str, Any]:
        return {
            "buffer": self._buffer,
            "buffer_tokens": self._buffer_tokens,
            "data_provider": self._data_provider.state_dict(),
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self._buffer = state["buffer"]
        self._buffer_tokens = state["buffer_tokens"]
        self._data_provider.load_state_dict(state["data_provider"])
        self._is_resuming = True

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self._data_provider.sampler, "set_epoch"):
            self._data_provider.sampler.set_epoch(epoch)


if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.utils.batching \
        --model llamafactory/tiny-random-qwen2.5 \
        --train_dataset data/v1_sft_demo.yaml \
        --micro_batch_size 2 \
        --global_batch_size 4 \
        --batching_workers 0
    """
    from ...config.arg_parser import get_args
    from ..data_engine import DataEngine
    from ..model_engine import ModelEngine

    model_args, data_args, training_args, _ = get_args()
    data_engine = DataEngine(data_args.train_dataset)
    model_engine = ModelEngine(model_args=model_args)
    batch_generator = BatchGenerator(
        data_engine,
        model_engine.renderer,
        micro_batch_size=training_args.micro_batch_size,
        global_batch_size=training_args.global_batch_size,
        cutoff_len=training_args.cutoff_len,
        batching_workers=training_args.batching_workers,
        batching_strategy=training_args.batching_strategy,
    )
    for batch in batch_generator:
        print(batch)
        print(len(batch))
        print(batch[0]["input_ids"].shape)
        break
```
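The constructor enforces the invariant `global_batch_size == dp_size * micro_batch_size * num_micro_batch`. Worked through by hand for the values in the CLI example above, assuming a data-parallel world size of 2:

```python
dp_size, micro_batch_size, global_batch_size = 2, 2, 4
assert global_batch_size % (dp_size * micro_batch_size) == 0
num_micro_batch = global_batch_size // dp_size // micro_batch_size  # -> 1
# Each rank accumulates over 1 micro batch of 2 samples per optimizer step:
# 2 ranks * 2 samples * 1 micro batch = 4 samples, the global batch size.
```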
src/llamafactory/v1/core/utils/callback.py (new empty file, mode 100644)
src/llamafactory/v1/core/utils/inference_engine.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from threading import Thread

import torch
from transformers import AsyncTextIteratorStreamer

from ...accelerator.interface import DistributedInterface
from ...config import ModelArguments, SampleArguments
from ...utils.helper import get_tokenizer
from ...utils.types import HFModel, Message, Sample, TorchDataset
from .rendering import Renderer


class BaseEngine(ABC):
    @abstractmethod
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        """Initialize the engine.

        Args:
            args: Sample arguments.
            model_args: Model arguments.
            model: Model.
            renderer: Renderer.
        """
        ...

    @abstractmethod
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        ...

    @abstractmethod
    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        ...


class HuggingFaceEngine(BaseEngine):
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        self.args = args
        self.model_args = model_args
        self.model = model
        self.renderer = renderer
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @torch.inference_mode()
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        async with self.semaphore:
            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
            streamer = AsyncTextIteratorStreamer(
                tokenizer=get_tokenizer(self.renderer.processor),
                skip_prompt=True,
                skip_special_tokens=True,  # TODO: configurable
            )
            device = DistributedInterface().current_device
            kwargs = {
                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
                "max_new_tokens": self.args.max_new_tokens,
                "streamer": streamer,
            }
            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
            thread.start()
            async for token in streamer:
                yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        raise NotImplementedError("Batch infer is not implemented.")
```
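A consumption sketch for the streaming path: `generate` is an async generator, so callers drain it with `async for`. Engine construction is elided here (`engine` stands for an initialized `HuggingFaceEngine`); the message follows the content-list format used throughout this commit:

```python
import asyncio

async def chat(engine) -> None:
    messages = [{"role": "user", "content": [{"type": "text", "value": "Hello"}]}]
    async for token in engine.generate(messages):
        print(token, end="", flush=True)

# asyncio.run(chat(engine))
```

The `MAX_CONCURRENT` semaphore caps how many `generate` calls can run their background `model.generate` threads at once.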
src/llamafactory/v1/core/utils/rendering.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rendering utils.

How to use:
    renderer = Renderer(template, processor)
    renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
    renderer.parse_message(text: str) -> Message
    renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""

import numpy as np

from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor, Sample


def render_chatml_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    """Apply the ChatML template to messages and convert them to model input.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
    """
    tokenizer = get_tokenizer(processor)
    input_ids, labels, loss_weights = [], [], []
    for message in messages:
        temp_str = "<|im_start|>" + message["role"] + "\n"
        for content in message["content"]:
            if content["type"] == "text":
                temp_str += content["value"]
            else:
                raise ValueError(f"Unsupported content type: {content['type']}")

        temp_str += "<|im_end|>\n"
        temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
        temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
        input_ids.extend(temp_ids)
        loss_weights.extend([temp_weight] * len(temp_ids))
        if temp_weight > 1e-6:
            labels.extend(temp_ids)
        else:
            labels.extend([IGNORE_INDEX] * len(temp_ids))

    if is_generate:
        temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
        input_ids.extend(temp_ids)
        loss_weights.extend([0.0] * len(temp_ids))
        labels.extend([IGNORE_INDEX] * len(temp_ids))

    return ModelInput(
        input_ids=input_ids,
        attention_mask=[1] * len(input_ids),
        labels=labels,
        loss_weights=loss_weights,
    )


def parse_chatml_message(generated_text: str) -> Message:
    """Parse a message in ChatML format.

    Args:
        generated_text (str): The generated text in ChatML format.

    Returns:
        Message: The parsed message.
    """
    return Message(role="assistant", content=[{"type": "text", "value": generated_text}])


class Renderer:
    def __init__(self, template: str, processor: Processor):
        self.template = template
        self.processor = processor

    def render_messages(
        self,
        messages: list[Message],
        tools: str | None = None,
        is_generate: bool = False,
        enable_thinking: bool = False,
    ) -> ModelInput:
        """Apply the template to messages and convert them to model input.

        Args:
            messages (list[Message]): The messages to render.
            tools (str | None, optional): The tools to use. Defaults to None.
            is_generate (bool, optional): Whether to render for generation. Defaults to False.
            enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.

        Returns:
            ModelInput: The rendered model input.
        """
        if self.template == "chatml":
            return render_chatml_messages(self.processor, messages, tools, is_generate)
        else:
            from ...plugins.model_plugins.rendering import RenderingPlugin

            return RenderingPlugin(self.template).render_messages(
                self.processor, messages, tools, is_generate, enable_thinking
            )

    def parse_message(self, generated_text: str) -> Message:
        """Parse a message in the template format.

        Args:
            generated_text (str): The generated text in the template format.

        Returns:
            Message: The parsed message.
        """
        if self.template == "chatml":
            return parse_chatml_message(generated_text)
        else:
            from ...plugins.model_plugins.rendering import RenderingPlugin

            return RenderingPlugin(self.template).parse_message(generated_text)

    def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
        """Process samples to model input.

        Args:
            samples (list[Sample]): The samples to process.

        Returns:
            list[ModelInput]: The processed model inputs.
        """
        model_inputs = []
        for sample in samples:
            if "messages" in sample:
                model_input = self.render_messages(sample["messages"], sample.get("tools"))
            elif "chosen_messages" in sample and "rejected_messages" in sample:
                chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
                rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
                chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
                rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
                model_input = ModelInput(
                    input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
                    attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
                    labels=chosen_input["labels"] + rejected_input["labels"],
                    loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
                    token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
                )
                if "position_ids" in chosen_input:
                    model_input["position_ids"] = np.concatenate(
                        [chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
                    )
            else:
                raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")

            if "extra_info" in sample:
                model_input["extra_info"] = sample["extra_info"]

            if "_dataset_name" in sample:
                model_input["_dataset_name"] = sample["_dataset_name"]

            model_inputs.append(model_input)

        return model_inputs
```
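To make the masking rule concrete: only spans whose `loss_weight` exceeds 1e-6 keep their token ids in `labels`; everything else becomes `IGNORE_INDEX`. A hand-traced sketch for a two-turn exchange (exact ids depend on the tokenizer):

```python
messages = [
    {"role": "user", "content": [{"type": "text", "value": "Hi"}]},
    {"role": "assistant", "content": [{"type": "text", "value": "Hello!"}]},
]
# Rendered string before tokenization:
#   <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n
# loss_weights: 0.0 on every user token, 1.0 on every assistant token.
# labels mirror input_ids on the assistant span and are IGNORE_INDEX elsewhere.
# With is_generate=True, "<|im_start|>assistant\n" is appended as a zero-weight prompt.
```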
src/llamafactory/v1/launcher.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

from ..extras.env import VERSION, print_env


USAGE = (
    "-" * 70
    + "\n"
    + "| Usage:                                                             |\n"
    + "| llamafactory-cli sft -h: train models                              |\n"
    + "| llamafactory-cli version: show version info                        |\n"
    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`.      |\n"
    + "-" * 70
)

WELCOME = (
    "-" * 58
    + "\n"
    + f"| Welcome to LLaMA Factory, version {VERSION}"
    + " " * (21 - len(VERSION))
    + "|\n|"
    + " " * 56
    + "|\n"
    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
    + "-" * 58
)


def launch():
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command == "sft":  # the train command falls back to the sft command
        from .trainers.sft_trainer import run_sft

        run_sft()
    elif command == "env":
        print_env()
    elif command == "version":
        print(WELCOME)
    elif command == "help":
        print(USAGE)
    else:
        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    pass
```
src/llamafactory/v1/plugins/__init__.py (new empty file, mode 100644)

src/llamafactory/v1/plugins/data_plugins/__init__.py (new empty file, mode 100644)
src/llamafactory/v1/plugins/data_plugins/converter.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Literal, NotRequired, TypedDict

from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample


logger = logging.get_logger(__name__)


class AlpacaSample(TypedDict, total=False):
    system: NotRequired[str]
    instruction: str
    input: NotRequired[str]
    output: str


SharegptMessage = TypedDict(
    "SharegptMessage",
    {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str},
)


class SharegptSample(TypedDict, total=False):
    conversations: list[SharegptMessage]
    tools: NotRequired[str]


class OpenaiMessage(TypedDict, total=False):
    role: Literal["user", "assistant", "tool"]
    content: str


class OpenaiSample(TypedDict, total=False):
    messages: list[OpenaiMessage]


class PairSample(TypedDict, total=False):
    chosen: list[OpenaiMessage]
    rejected: list[OpenaiMessage]


class DataConverterPlugin(BasePlugin):
    """Plugin for data converters."""

    def __call__(self, raw_sample: dict[str, Any]) -> Sample:
        return super().__call__(raw_sample)


@DataConverterPlugin("alpaca").register
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
    """Convert an Alpaca sample to an SFT sample.

    See a raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en

    Args:
        raw_sample (AlpacaSample): Alpaca sample.

    Returns:
        SFTSample: SFT sample.
    """
    messages = []
    if "system" in raw_sample:
        messages.append(
            {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
        )

    if "instruction" in raw_sample or "input" in raw_sample:
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "value": raw_sample.get("instruction", "") + raw_sample.get("input", "")}
                ],
                "loss_weight": 0.0,
            }
        )

    if "output" in raw_sample:
        messages.append(
            {"role": "assistant", "content": [{"type": "text", "value": raw_sample["output"]}], "loss_weight": 1.0}
        )

    return {"messages": messages}


@DataConverterPlugin("sharegpt").register
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
    """Convert a ShareGPT sample to an SFT sample.

    See a raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en

    Args:
        raw_sample (SharegptSample): ShareGPT sample.

    Returns:
        SFTSample: SFT sample.
    """
    tag_mapping = {
        "system": "system",
        "human": "user",
        "gpt": "assistant",
        "observation": "tool",
        "function_call": "assistant",
    }
    messages = []
    tools = raw_sample.get("tools", "")
    for message in raw_sample.get("conversations", []):
        tag = message["from"]
        if tag not in tag_mapping:
            logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
        elif tag == "function_call":
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "tool_calls", "value": message["value"]}],
                    "loss_weight": 1.0,
                }
            )
        else:
            messages.append(
                {
                    "role": tag_mapping[tag],
                    "content": [{"type": "text", "value": message["value"]}],
                    "loss_weight": 1.0 if tag == "gpt" else 0.0,
                }
            )

    if tools:
        if messages and messages[0]["role"] == "system":
            messages[0]["content"].append({"type": "tools", "value": tools})
        else:
            messages.insert(
                0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0}
            )

    return {"messages": messages}


@DataConverterPlugin("pair").register
def pair_converter(raw_sample: PairSample) -> DPOSample:
    """Convert a pair sample to a DPO sample.

    See a raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs

    Args:
        raw_sample (PairSample): Pair sample with chosen, rejected fields.

    Returns:
        DPOSample: DPO sample with chosen_messages and rejected_messages.
    """

    def process_message(raw_messages: list[OpenaiMessage]):
        messages = []
        for message in raw_messages:
            messages.append(
                {
                    "role": message["role"],
                    "content": [{"type": "text", "value": message["content"]}],
                    "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
                }
            )

        return messages

    chosen_messages = process_message(raw_sample.get("chosen", []))
    rejected_messages = process_message(raw_sample.get("rejected", []))
    return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages}
```
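Traced by hand, the Alpaca converter maps a typical record like this (assuming the plugin's `register` decorator leaves the function directly callable, which this commit does not show):

```python
raw = {"instruction": "Translate to French: ", "input": "Hello", "output": "Bonjour"}
sample = alpaca_converter(raw)
# sample == {"messages": [
#     {"role": "user",
#      "content": [{"type": "text", "value": "Translate to French: Hello"}],
#      "loss_weight": 0.0},
#     {"role": "assistant",
#      "content": [{"type": "text", "value": "Bonjour"}],
#      "loss_weight": 1.0},
# ]}
```

Note that `instruction` and `input` are concatenated into a single user turn, so any separator has to live in the raw fields themselves.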
src/llamafactory/v1/plugins/data_plugins/loader.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from typing import Any, Literal

from datasets import load_dataset

from ...utils.plugin import BasePlugin
from ...utils.types import DatasetInfo, HFDataset


class DataLoaderPlugin(BasePlugin):
    """Plugin for loading datasets."""

    def load(self, dataset_info: DatasetInfo) -> HFDataset:
        path = dataset_info["path"]
        split = dataset_info.get("split", "train")
        streaming = dataset_info.get("streaming", False)
        return super().__call__(path, split, streaming)


def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
    """Get the dataset builder name.

    Args:
        path (str): Dataset path.

    Returns:
        Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
    """
    filetype = os.path.splitext(path)[-1][1:]
    if filetype in ["arrow", "csv", "json", "jsonl", "parquet", "txt"]:
        return filetype.replace("jsonl", "json").replace("txt", "text")
    else:
        raise ValueError(f"Unknown dataset filetype: {filetype}.")


@DataLoaderPlugin("local").register
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
    if os.path.isdir(filepath):
        filetype = _get_builder_name(os.listdir(filepath)[0])
        dataset = load_dataset(filetype, data_dir=filepath, split=split)
    elif os.path.isfile(filepath):
        filetype = _get_builder_name(filepath)
        dataset = load_dataset(filetype, data_files=filepath, split=split)
    else:
        raise ValueError(f"Can not load dataset from {filepath}.")

    if streaming:  # faster when data is streamed from local files
        dataset = dataset.to_iterable_dataset()

    return dataset


class DataIndexPlugin(BasePlugin):
    """Plugin for adjusting the dataset index."""

    def adjust_data_index(
        self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
    ) -> list[tuple[str, int]]:
        """Adjust the dataset index by size and weight.

        Args:
            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
            size (int | None): Desired dataset size.
            weight (float | None): Desired dataset weight.

        Returns:
            list[tuple[str, int]]: Adjusted dataset index.
        """
        if size is not None:
            data_index = random.choices(data_index, k=size)

        if weight is not None:
            data_index = random.choices(data_index, k=int(len(data_index) * weight))

        return data_index


class DataSelectorPlugin(BasePlugin):
    """Plugin for selecting dataset samples."""

    def select(
        self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
    ) -> tuple[str, int] | list[tuple[str, int]]:
        """Select dataset samples.

        Args:
            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
            index (slice | list[int] | Any): Index of dataset samples.

        Returns:
            tuple[str, int] | list[tuple[str, int]]: Selected dataset samples.
        """
        if isinstance(index, slice):
            return [data_index[i] for i in range(*index.indices(len(data_index)))]
        elif isinstance(index, list):
            return [data_index[i] for i in index]
        else:
            raise ValueError(f"Invalid index type {type(index)}.")
```
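The extension-to-builder mapping is mechanical; these checks follow directly from `_get_builder_name` as written:

```python
assert _get_builder_name("data/train.jsonl") == "json"      # jsonl is aliased to the json builder
assert _get_builder_name("data/corpus.txt") == "text"       # txt is aliased to the text builder
assert _get_builder_name("data/shard.parquet") == "parquet"
# Any other extension raises ValueError("Unknown dataset filetype: ...").
```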
src/llamafactory/v1/plugins/data_plugins/template.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass


@dataclass
class Template:
    user_template: str
    assistant_template: str
    system_template: str

    def render_message(self, message: dict[str, str]) -> str:
        return self.user_template.format(**message)


@dataclass
class QwenTemplate:
    message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME: handle role: tool
    thinking_template: str = "<think>\n{content}\n</think>\n\n"

    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
        if isinstance(content_data, str):
            return content_data.strip()

        if isinstance(content_data, list):
            parts = []
            for item in content_data:
                if item.get("type") == "text":
                    parts.append(item.get("value", ""))
                elif item.get("type") == "image_url":
                    pass

            return "\n".join(parts).strip()

        return ""

    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
        role = message["role"]
        content = self._extract_content(message.get("content", ""))
        if role == "assistant":
            reasoning_content = message.get("reasoning_content", "")
            if reasoning_content:
                reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())

            return self.message_template.format(role="assistant", content=reasoning_content + content)
        else:
            return self.message_template.format(role=role, content=content)

    def encode_messages(
        self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192
    ) -> dict[str, list[int]]:
        """Encode messages into model inputs."""
        input_ids, attention_mask, labels = [], [], []
        for message in messages:
            content_str = self.render_message(message)
            content_ids = tokenizer.encode(content_str, add_special_tokens=False)
            input_ids += content_ids
            attention_mask += [1] * len(content_ids)
            if "loss_weight" in message:  # messages are dicts, so test key membership
                loss_weight = message["loss_weight"]
            else:
                loss_weight = 1 if message["role"] == "assistant" else 0

            if loss_weight == 1:
                labels += content_ids
            else:
                labels += [-100] * len(content_ids)

        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        model_inputs.update({"position_ids": list(range(len(input_ids)))})
        model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
        return model_inputs


if __name__ == "__main__":

    def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
        out = []
        for m in messages:
            role = m["role"]
            content = template._extract_content(m.get("content", ""))
            if role == "assistant":
                reasoning = (m.get("reasoning_content") or "").strip()
                if reasoning:
                    content = template.thinking_template.format(content=reasoning) + content

            out.append({"role": role, "content": content})

        return out

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        trust_remote_code=True,
    )
    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": "What is 1+1?"}, {"type": "text", "text": "What is 2+2?"}],
        },
        {
            "role": "assistant",
            "reasoning_content": "This is a simple math problem. The result of 1 plus 1 is 2.",
            "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
        },
    ]
    template = QwenTemplate()
    rendered_custom = "".join([template.render_message(m) for m in test_messages])
    qwen3_messages = to_qwen3_messages(template, test_messages)
    rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
    print("==== custom ====")
    print(rendered_custom)
    print("==== hf ====")
    print(rendered_hf)
    assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
    ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
    ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
    assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
```
src/llamafactory/v1/plugins/model_plugins/__init__.py (new empty file, mode 100644)

src/llamafactory/v1/plugins/model_plugins/add_token.py (new empty file, mode 100644)

src/llamafactory/v1/plugins/model_plugins/added_token.py (new empty file, mode 100644)

src/llamafactory/v1/plugins/model_plugins/initialization.py (new empty file, mode 100644)

src/llamafactory/v1/plugins/model_plugins/kernels/__init__.py (new empty file, mode 100644)
src/llamafactory/v1/plugins/model_plugins/kernels/base.py (new file, mode 100644)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of the base kernel class.

Init Phase:
    1. Define base kernel class.
    2. Define abstract methods.
"""

from abc import ABC, abstractmethod
from typing import Any

from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel


class BaseKernel(ABC):
    r"""Base class for all kernel implementations.

    Subclasses must implement the abstract methods and define the required class attributes.
    """

    _kernel_id: Any = ""  # kernel ID, any hashable value to identify a kernel implementation
    _device: DeviceType = DeviceType.CPU  # "cuda", "npu", "cpu", etc.

    @classmethod
    def get_kernel_id(cls) -> str:
        r"""Returns the unique identifier for the kernel."""
        return cls._kernel_id

    @classmethod
    def get_device(cls) -> str:
        r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
        return cls._device

    @classmethod
    def check_deps(cls) -> bool:
        r"""Checks if the required dependencies for the kernel are available.

        Returns:
            bool: ``True`` if dependencies are met, ``False`` otherwise.

        .. note::
            In explicit mode, if a user specifies an implementation but this check fails,
            it should raise an error instead of silently switching.
            Kernels can override this method to implement custom dependency checks.
        """
        if cls._device != get_current_accelerator().type:
            return False

        return True

    @classmethod
    @abstractmethod
    def apply(cls, **kwargs) -> HFModel:
        r"""Applies the kernel optimization to the model.

        Args:
            **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

        Returns:
            HFModel: The model with the kernel applied.

        Raises:
            RuntimeError: If the kernel dependencies are not met.
            NotImplementedError: If the method is not implemented by the subclass.

        Example:
            >>> from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
            >>> model = HFModel(config=config)
            >>> model = apply_kernel(model=model, kernel_id="npu_fused_moe")
        """
        if not cls.check_deps():
            raise RuntimeError(f"{cls.__name__} is not available but {cls.__name__} kernel was called.")

        raise NotImplementedError
```
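A minimal subclass sketch to show the contract; the kernel id, the patching logic, and the `DeviceType.CUDA` member are all assumptions for illustration (the file above only references `DeviceType.CPU` directly):

```python
class FusedRMSNormKernel(BaseKernel):  # hypothetical kernel, not part of this commit
    _kernel_id = "fused_rmsnorm"
    _device = DeviceType.CUDA  # assumed member, per the "cuda", "npu", "cpu" comment

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        if not cls.check_deps():
            raise RuntimeError(f"{cls.__name__} requires a matching accelerator.")

        model = kwargs["model"]
        # ... swap the model's normalization modules for the fused implementation ...
        return model
```

Because `check_deps` compares `_device` against the current accelerator, a mismatched subclass is rejected before any patching happens.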