OpenDAS / LLaMA-Factory · Commits

Commit ca625f43, authored Mar 30, 2026 by shihm
    uodata
Parent: 7164651d

Showing 20 changed files with 1387 additions and 0 deletions (+1387, -0)
src/llamafactory/v1/plugins/trainer_plugins/distributed/__init__.py    +0    -0
src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py  +0    -0
src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py   +0    -0
src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py       +478  -0
src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py         +61   -0
src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py            +19   -0
src/llamafactory/v1/plugins/trainer_plugins/optimizer.py               +19   -0
src/llamafactory/v1/samplers/cli_sampler.py                            +125  -0
src/llamafactory/v1/trainers/__init__.py                               +0    -0
src/llamafactory/v1/trainers/dpo_trainer.py                            +0    -0
src/llamafactory/v1/trainers/rm_trainer.py                             +0    -0
src/llamafactory/v1/trainers/sft_trainer.py                            +38   -0
src/llamafactory/v1/utils/__init__.py                                  +0    -0
src/llamafactory/v1/utils/batching_queue.py                            +220  -0
src/llamafactory/v1/utils/constants.py                                 +13   -0
src/llamafactory/v1/utils/dtype.py                                     +91   -0
src/llamafactory/v1/utils/env.py                                       +30   -0
src/llamafactory/v1/utils/helper.py                                    +103  -0
src/llamafactory/v1/utils/logging.py                                   +123  -0
src/llamafactory/v1/utils/objects.py                                   +67   -0
src/llamafactory/v1/plugins/trainer_plugins/distributed/__init__.py (new file, mode 100644, empty)

src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py (new file, mode 100644, empty)

src/llamafactory/v1/plugins/trainer_plugins/distributed/deepspeed.py (new file, mode 100644, empty)
src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import gc
import os

import torch
import torch.nn as nn
from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)

from ....accelerator.helper import get_current_accelerator
from ....accelerator.interface import DistributedInterface
from ....utils.logging import get_logger
from ....utils.types import HFModel, Processor


logger = get_logger(__name__)


def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
    no_split_modules = getattr(model, "_no_split_modules", None)
    if no_split_modules:
        if isinstance(no_split_modules, (list, tuple)):
            for name, module in model.named_modules():
                for cls_name in no_split_modules:
                    if module.__class__.__name__ == cls_name:
                        return module.__class__

    if hasattr(model, "model") and hasattr(model.model, "layers"):
        return type(model.model.layers[0])

    if hasattr(model, "layers"):
        return type(model.layers[0])

    return None


def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
    if DistributedInterface().get_rank() == 0:
        logger.info("Gathering state dict for saving...")

    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    state_dict = get_model_state_dict(model, options=options)
    if DistributedInterface().get_rank() == 0:
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
        processor.save_pretrained(output_dir, max_shard_size="4GB")
        logger.info(f"Model saved to {output_dir}")


class FSDP2Engine:
    def __init__(self, dist_config: dict):
        self.dist_interface = DistributedInterface()
        self.rank = self.dist_interface.get_rank()
        self.local_rank = self.dist_interface.get_local_rank()
        self.world_size = self.dist_interface.get_world_size()
        self.mixed_precision = dist_config.get("mixed_precision", "bf16")
        self.reshard_after_forward = dist_config.get("reshard_after_forward", True)
        self.offload_params = dist_config.get("offload_params", False)
        self.pin_memory = dist_config.get("pin_memory", True)
        self.dcp_path = dist_config.get("dcp_path", None)
        self.device_mesh = self.dist_interface.data_device_mesh
        if self.device_mesh is None:
            logger.warning(
                "Device Mesh not found in DistributedInterface. FSDP2 might fail if not running in distributed mode."
            )

        if self.device_mesh is not None:
            try:
                self.fsdp_mesh = self.device_mesh["dp"]
            except Exception:
                self.fsdp_mesh = self.device_mesh

            logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
        else:
            self.fsdp_mesh = None

    def get_mp_policy(self) -> MixedPrecisionPolicy:
        if self.mixed_precision == "bf16":
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
        elif self.mixed_precision == "fp16":
            param_dtype = torch.float16
            reduce_dtype = torch.float32
        else:
            param_dtype = torch.float32
            reduce_dtype = torch.float32

        return MixedPrecisionPolicy(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            cast_forward_inputs=True,
        )

    def is_lora_module_wrap(self, model) -> bool:
        return any(isinstance(module, LoraLayer) for module in model.modules())

    def prepare_model(self, model: HFModel) -> HFModel:
        if self.fsdp_mesh is None:
            logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
            return model

        mp_policy = self.get_mp_policy()
        layer_cls = get_transformer_layer_cls(model)
        if layer_cls is None:
            logger.warning(
                "Could not identify Transformer Layer class, applying FSDP to the whole model structure only."
            )
            transformer_layer_cls_to_wrap = set()
        else:
            logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
            transformer_layer_cls_to_wrap = {layer_cls}

        if self.is_lora_module_wrap(model):
            lora_modules = []
            for module in model.modules():
                if len(list(module.children())) != 0:
                    continue

                if any(param.requires_grad for param in module.parameters(recurse=False)):
                    lora_modules.append(module)

            for module in lora_modules:
                fully_shard(
                    module,
                    mesh=self.fsdp_mesh,
                    reshard_after_forward=self.reshard_after_forward,
                    mp_policy=mp_policy,
                    offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
                )

            logger.info("Applying FSDP wrap for LoRA layer separately.")

        for name, module in model.named_modules():
            should_wrap = False
            if type(module) in transformer_layer_cls_to_wrap:
                should_wrap = True
            elif isinstance(module, nn.Embedding):
                if not getattr(model.config, "tie_word_embeddings", True):
                    should_wrap = True

            if should_wrap:
                fully_shard(
                    module,
                    mesh=self.fsdp_mesh,
                    reshard_after_forward=self.reshard_after_forward,
                    mp_policy=mp_policy,
                    offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
                )

        # BaseTrainer is the single source of truth for gradient checkpointing.
        # FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
        if getattr(model, "is_gradient_checkpointing", False):
            if self.rank == 0:
                logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")

            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        fully_shard(
            model,
            mesh=self.fsdp_mesh,
            reshard_after_forward=self.reshard_after_forward,
            mp_policy=mp_policy,
            offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
        )
        return model

    @torch.no_grad()
    def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
        if self.rank == 0:
            logger.info("Materializing sharded model params...")

        device = get_current_accelerator()
        model.to_empty(device=device)

        if dcp_path and os.path.exists(dcp_path):
            if self.rank == 0:
                logger.info(f"DCP path found at {dcp_path}. Using efficient Sharded Loading (DCP Load).")

            self._load_from_dcp(model, dcp_path)
        else:
            if self.rank == 0:
                if dcp_path:
                    logger.warning(f"DCP path {dcp_path} not found.")

                logger.info("Using HF Meta Loading (Chunk Load).")

            self._load_weights_from_hf_checkpoint(model, hf_model_path)

        return model

    def _save_non_persistent_buffers(self, model: HFModel) -> dict:
        """Save non-persistent buffers, such as inv_freq."""
        saved = {}
        for mod_name, module in model.named_modules():
            for buf_name in module._non_persistent_buffers_set:
                fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
                buf = getattr(module, buf_name, None)
                if buf is not None:
                    saved[fqn] = copy.deepcopy(buf)

        if self.rank == 0 and saved:
            logger.info(f"Saved {len(saved)} non-persistent buffers")

        return saved

    def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
        """Register saved non-persistent buffers to model."""
        if not saved_buffers:
            return

        device = get_current_accelerator()
        for fqn, buf in saved_buffers.items():
            buf = buf.to(device)
            if "." in fqn:
                parent_fqn, buf_name = fqn.rsplit(".", 1)
                parent_module = model.get_submodule(parent_fqn)
            else:
                buf_name = fqn
                parent_module = model

            parent_module.register_buffer(buf_name, buf, persistent=False)

        if self.rank == 0:
            logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")

    def shard_model(self, model: HFModel) -> HFModel:
        if model.device.type == "meta":
            non_persistent_buffers = self._save_non_persistent_buffers(model)
            if getattr(model.config, "tie_word_embeddings", None):
                model.tie_weights()

            model = self.prepare_model(model)
            model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
            # fix tied broken for no-fsdp-wrap case
            if getattr(model.config, "tie_word_embeddings", None):
                model.tie_weights()

            self._restore_non_persistent_buffers(model, non_persistent_buffers)
        else:
            model = self.prepare_model(model)

        return model

    def _load_from_dcp(self, model: HFModel, dcp_path: str):
        import torch.distributed.checkpoint as dcp

        try:
            if self.rank == 0:
                logger.info(f"Loading distributed checkpoint from {dcp_path} ...")

            options = StateDictOptions(full_state_dict=False, cpu_offload=True)
            local_state_dict = get_model_state_dict(model, options=options)
            dcp.load(state_dict=local_state_dict, checkpoint_id=dcp_path)
            set_model_state_dict(model, local_state_dict, options=options)
            if self.rank == 0:
                logger.info("DCP weights loaded successfully.")

        except Exception as e:
            logger.error(f"Failed to load from DCP: {e}")
            raise e

    def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
        import glob
        import json

        hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)

        if self.rank == 0:
            logger.info(f"Loading weights from {hf_model_path} ...")

        index_file = os.path.join(hf_model_path, "model.safetensors.index.json")
        is_safetensors = True
        checkpoint_files = []
        if os.path.exists(index_file):
            with open(index_file) as f:
                index = json.load(f)

            checkpoint_files = sorted(set(index["weight_map"].values()))
            checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
        elif os.path.exists(os.path.join(hf_model_path, "model.safetensors")):
            checkpoint_files = [os.path.join(hf_model_path, "model.safetensors")]
        else:
            is_safetensors = False
            index_file = os.path.join(hf_model_path, "pytorch_model.bin.index.json")
            if os.path.exists(index_file):
                with open(index_file) as f:
                    index = json.load(f)

                checkpoint_files = sorted(set(index["weight_map"].values()))
                checkpoint_files = [os.path.join(hf_model_path, f) for f in checkpoint_files]
            elif os.path.exists(os.path.join(hf_model_path, "pytorch_model.bin")):
                checkpoint_files = [os.path.join(hf_model_path, "pytorch_model.bin")]
            else:
                checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.safetensors")))
                if checkpoint_files:
                    is_safetensors = True
                else:
                    checkpoint_files = sorted(glob.glob(os.path.join(hf_model_path, "*.bin")))

        if not checkpoint_files:
            raise ValueError(f"No checkpoint files found in {hf_model_path}")

        param_map = dict(model.named_parameters())
        total_files = len(checkpoint_files)
        for i, ckpt_file in enumerate(checkpoint_files):
            if self.rank == 0:
                logger.info(f"[{i + 1}/{total_files}] Loading {os.path.basename(ckpt_file)} ...")

            if is_safetensors:
                from safetensors import safe_open

                with safe_open(ckpt_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key in param_map:
                            tensor = f.get_tensor(key)
                            self._copy_weights(param_map[key], tensor)
            else:
                state_dict = torch.load(ckpt_file, map_location="cpu")
                for key, tensor in state_dict.items():
                    if key in param_map:
                        self._copy_weights(param_map[key], tensor)

                del state_dict

            gc.collect()

    def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
        """Resolve a HF model identifier or local path to a local directory containing checkpoint files.

        - If `hf_model_path` is an existing directory, return it.
        - If it's a file path, return its parent directory.
        - Otherwise treat it as a Hugging Face Hub repo id and download/resolve to the local cache dir.
        """
        if not hf_model_path:
            return hf_model_path

        # Local directory or file path.
        if os.path.isdir(hf_model_path):
            return hf_model_path

        if os.path.isfile(hf_model_path):
            return os.path.dirname(hf_model_path)

        # HuggingFace Hub repo id: snapshot to local cache so we can glob/index files.
        try:
            from huggingface_hub import snapshot_download
        except ImportError as e:
            raise ValueError(
                f"hf_model_path='{hf_model_path}' does not exist locally and huggingface_hub is not available "
                f"to download it. Please provide a local model directory or install huggingface_hub. Error: {e}"
            ) from e

        revision = os.getenv("HF_REVISION")
        offline = os.getenv("HF_HUB_OFFLINE") == "1" or os.getenv("TRANSFORMERS_OFFLINE") == "1"

        # In distributed runs, let rank0 download first to avoid N-way concurrent downloads.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if self.rank == 0:
                local_dir = snapshot_download(
                    repo_id=hf_model_path,
                    revision=revision,
                    local_files_only=offline,
                    allow_patterns=[
                        "*.safetensors",
                        "*.bin",
                        "*.index.json",
                        "model.safetensors",
                        "model.safetensors.index.json",
                        "pytorch_model.bin",
                        "pytorch_model.bin.index.json",
                        "config.json",
                    ],
                )
                logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")

            torch.distributed.barrier()
            if self.rank != 0:
                local_dir = snapshot_download(
                    repo_id=hf_model_path,
                    revision=revision,
                    local_files_only=True,
                    allow_patterns=[
                        "*.safetensors",
                        "*.bin",
                        "*.index.json",
                        "model.safetensors",
                        "model.safetensors.index.json",
                        "pytorch_model.bin",
                        "pytorch_model.bin.index.json",
                        "config.json",
                    ],
                )

            return local_dir

        local_dir = snapshot_download(
            repo_id=hf_model_path,
            revision=revision,
            local_files_only=offline,
            allow_patterns=[
                "*.safetensors",
                "*.bin",
                "*.index.json",
                "model.safetensors",
                "model.safetensors.index.json",
                "pytorch_model.bin",
                "pytorch_model.bin.index.json",
                "config.json",
            ],
        )
        if self.rank == 0:
            logger.info(f"Resolved HF repo id '{hf_model_path}' to local dir: {local_dir}")

        return local_dir

    def _copy_weights(self, param, loaded_tensor):
        from torch.distributed._tensor import DTensor, Shard

        if loaded_tensor.dtype != param.dtype:
            loaded_tensor = loaded_tensor.to(param.dtype)

        if isinstance(param, DTensor):
            shard_placement = None
            mesh_dim = -1
            for i, placement in enumerate(param.placements):
                if isinstance(placement, Shard):
                    shard_placement = placement
                    mesh_dim = i
                    break

            local_tensor = param.to_local()
            if shard_placement is None:
                local_tensor.copy_(loaded_tensor)
            else:
                dim = shard_placement.dim
                mesh = param.device_mesh
                my_coordinate = mesh.get_coordinate()
                if my_coordinate is None:
                    return

                rank_in_dim = my_coordinate[mesh_dim]
                world_size_in_dim = mesh.size(mesh_dim)
                full_size = param.shape[dim]
                chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim
                start = rank_in_dim * chunk_size
                end = min(start + chunk_size, full_size)
                if start >= full_size:
                    return

                sliced_tensor = loaded_tensor.narrow(dim, start, end - start)
                slices = [slice(None)] * local_tensor.ndim
                slices[dim] = slice(0, sliced_tensor.shape[dim])
                local_tensor[tuple(slices)].copy_(sliced_tensor)
        else:
            param.data.copy_(loaded_tensor)
```
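To make the chunk arithmetic in `_copy_weights` concrete: for a `Shard(dim)` placement it assigns each rank a ceil-divided slice of the full weight along that dim, with the last rank possibly getting a short (or empty) slice. Below is a minimal standalone sketch of just that arithmetic, with no distributed setup; the sizes are made up for illustration.

```python
# Even chunking as in FSDP2Engine._copy_weights: a weight with 10 rows on dim 0,
# sharded across 4 ranks. chunk_size is a ceiling division, so ranks 0-2 get 3
# rows each and rank 3 gets the 1 remaining row.
full_size, world_size_in_dim = 10, 4
chunk_size = (full_size + world_size_in_dim - 1) // world_size_in_dim  # -> 3

for rank_in_dim in range(world_size_in_dim):
    start = rank_in_dim * chunk_size
    end = min(start + chunk_size, full_size)
    if start >= full_size:
        print(f"rank {rank_in_dim}: empty shard")  # padding-only shard, copy skipped
        continue
    print(f"rank {rank_in_dim}: rows [{start}, {end})")  # [0,3) [3,6) [6,9) [9,10)
```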
src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from ....config.arg_utils import PluginConfig
from ....utils.plugin import BasePlugin


if TYPE_CHECKING:
    from ....utils.types import HFModel, Processor


class DistributedPlugin(BasePlugin):
    def __call__(self, model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
        return super().__call__(model, dist_config, **kwargs)


@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
    from .fsdp2 import FSDP2Engine

    return FSDP2Engine(dist_config).shard_model(model)


@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
    from .fsdp2 import save_model

    return save_model(model, output_dir, processor)


@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
    from .deepspeed import DeepSpeedEngine

    return DeepSpeedEngine(
        dist_config,
        num_micro_batch=kwargs.get("num_micro_batch"),
        micro_batch_size=kwargs.get("micro_batch_size"),
    ).shard_model(model)


@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
    from .deepspeed import save_model

    return save_model(model, output_dir, processor)
```
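A hedged sketch of how a caller might reach these registrations. The exact dispatch semantics live in `BasePlugin` (`....utils.plugin`), which is not part of this commit, so the invocation form and the `"name"` selector key below are assumptions; the remaining `dist_config` keys mirror those read by `FSDP2Engine.__init__`.

```python
# Hypothetical usage sketch; BasePlugin's real dispatch API may differ.
dist_config = {
    "name": "fsdp2",              # assumed selector key, defined by BasePlugin
    "mixed_precision": "bf16",    # keys below are read by FSDP2Engine.__init__
    "reshard_after_forward": True,
    "offload_params": False,
}
# model = DistributedPlugin("fsdp2")(model, dist_config)  # -> shard_model_fsdp2
# DistributedPlugin("fsdp2").register("save_model") suggests named sub-hooks,
# so saving would route through save_model_fsdp2 the same way.
```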
src/llamafactory/v1/plugins/trainer_plugins/lr_scheduler.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils.plugin import BasePlugin


class LRSchedulerPlugin(BasePlugin):
    pass
```
src/llamafactory/v1/plugins/trainer_plugins/optimizer.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils.plugin import BasePlugin


class OptimizerPlugin(BasePlugin):
    pass
```
src/llamafactory/v1/samplers/cli_sampler.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from collections.abc import Generator
from threading import Thread

from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.utils.rendering import Renderer
from ..utils.types import HFModel, Message, Sample, TorchDataset


class SyncSampler(BaseSampler):
    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
            asyncio.set_event_loop(loop)
            loop.run_forever()

        super().__init__(args, model_args, model, renderer)
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
        """Generate tokens synchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        generator = super().generate(messages, tools)
        while True:
            try:
                token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
                yield token
            except StopAsyncIteration:
                break

    def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples synchronously.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()


def run_chat(args: InputArgument = None):
    model_args, data_args, _, sample_args = get_args(args)
    if sample_args.sample_backend != SampleBackend.HF:
        model_args.init_plugin = {"name": "init_on_meta"}

    model_engine = ModelEngine(model_args)
    sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
    if data_args.train_dataset is not None:
        dataset = DataEngine(data_args.train_dataset)
        sampler.batch_infer(dataset)
    else:
        if os.name != "nt":
            try:
                import readline  # noqa: F401
            except ImportError:
                print("Install `readline` for a better experience.")

        messages = []
        print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
        while True:
            try:
                query = input("\nUser: ")
            except UnicodeDecodeError:
                print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
                continue
            except Exception:
                raise

            if query.strip() == "exit":
                break

            if query.strip() == "clear":
                messages = []
                print("History has been removed.")
                continue

            messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
            print("Assistant: ", end="", flush=True)
            response = ""
            for new_text in sampler.generate(messages):
                print(new_text, end="", flush=True)
                response += new_text

            print()
            messages.append(model_engine.renderer.parse_message(response))


if __name__ == "__main__":
    run_chat()
```
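The core trick in `SyncSampler` is draining an async generator from synchronous code: the event loop runs on a daemon thread, and each `__anext__()` coroutine is submitted with `asyncio.run_coroutine_threadsafe`, whose future propagates `StopAsyncIteration` when the generator is exhausted. A self-contained sketch of just that pattern, using a stand-in async generator:

```python
import asyncio
from threading import Thread

async def agen():
    # Stand-in for BaseSampler.generate: an async token stream.
    for i in range(3):
        await asyncio.sleep(0)
        yield f"token-{i}"

def _run(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()

loop = asyncio.new_event_loop()
Thread(target=_run, args=(loop,), daemon=True).start()

gen = agen()
while True:
    try:
        # Blocks the calling thread until the loop thread produces the next item.
        print(asyncio.run_coroutine_threadsafe(gen.__anext__(), loop).result())
    except StopAsyncIteration:
        break
```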
src/llamafactory/v1/trainers/__init__.py (new file, mode 100644, empty)

src/llamafactory/v1/trainers/dpo_trainer.py (new file, mode 100644, empty)

src/llamafactory/v1/trainers/rm_trainer.py (new file, mode 100644, empty)

src/llamafactory/v1/trainers/sft_trainer.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_loader import ModelLoader


class SFTTrainer(BaseTrainer):
    pass


def run_sft(user_args):
    model_args, data_args, training_args, _ = get_args(user_args)
    DistributedInterface(training_args.dist_config)
    data_engine = DataEngine(data_args)
    model_loader = ModelLoader(model_args)
    trainer = SFTTrainer(
        args=training_args,
        model=model_loader.model,
        processor=model_loader.processor,
        dataset=data_engine,
    )
    trainer.fit()
```
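A hedged entry-point sketch. `run_sft` delegates argument parsing to `..config.arg_parser.get_args`, which is not part of this commit, so the argument keys below are illustrative placeholders only, not the real schema.

```python
# Hypothetical invocation; the actual keys are defined by arg_parser.get_args.
if __name__ == "__main__":
    run_sft({
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # assumed model-args key
        "train_dataset": "alpaca_en_demo",      # assumed data-args key
        "output_dir": "outputs/sft",            # assumed training-args key
    })
```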
src/llamafactory/v1/utils/__init__.py (new file, mode 100644, empty)

src/llamafactory/v1/utils/batching_queue.py (new file, mode 100644)
```python
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod


class DynamicBatchSizeBuffer:
    """A buffer to store samples for dynamic batch size."""

    def __init__(self):
        self._buffer: list[dict[str, any]] = []
        self._buffer_sample_lengths: list[int] = []
        self._deleted_indices: set[int] = set()
        self._current_index: int = 0
        self._total_token_count: int = 0

    def append(self, item: dict[str, any]) -> None:
        """Append a sample to the buffer.

        Args:
            item: A sample to append to the buffer.
                The sample should be a dict with the following keys:
                - input_ids: torch.Tensor of shape (seq_len, )
                - attention_mask: torch.Tensor of shape (seq_len, )
        """
        self._buffer.append(item)
        sample_length = int(item["attention_mask"].sum().item())
        self._buffer_sample_lengths.append(sample_length)
        self._total_token_count += sample_length

    def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]:
        """Get samples from the buffer that fit within the token budget.

        Args:
            max_tokens_per_iteration: Maximum number of tokens to retrieve.
            force: If True, the first available sample will be returned even
                if it exceeds the token budget.

        Returns:
            A list of samples that fit within the token budget.

        Raises:
            AssertionError: If no samples are found (should not happen in normal operation).
        """
        cum_seq_len = 0
        samples = []
        while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration:
            if self._current_index in self._deleted_indices:
                self._current_index += 1
                continue

            seq_len = self._buffer_sample_lengths[self._current_index]
            remaining_tokens = max_tokens_per_iteration - cum_seq_len
            # Check if we can add this sample
            can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens)
            if can_add:
                cum_seq_len += seq_len
                samples.append(self._buffer[self._current_index])
                self._deleted_indices.add(self._current_index)

            self._current_index += 1

        assert len(samples) > 0, "No samples found in buffer"
        return samples

    def __len__(self) -> int:
        """Return the number of samples in the buffer."""
        return len(self._buffer)

    @property
    def total_token_count(self) -> int:
        """Return the total number of tokens in the buffer."""
        return self._total_token_count

    def flush(self) -> None:
        tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices)
        self._total_token_count -= tokens_to_remove
        buffer_length = len(self._buffer)
        self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices]
        self._buffer_sample_lengths = [
            self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices
        ]
        self._current_index = 0
        self._deleted_indices.clear()


class BaseBatchingQueue(ABC):
    """Base class for batching queue."""

    @abstractmethod
    def is_full_filled(self) -> bool:
        raise NotImplementedError("Subclasses must implement `is_full_filled`")

    @abstractmethod
    def put_item(self, item: dict[str, any]) -> None:
        raise NotImplementedError("Subclasses must implement `put_item`")

    @abstractmethod
    def get_micro_batch(self, step: int) -> list[dict[str, any]]:
        raise NotImplementedError("Subclasses must implement `get_micro_batch`")

    @abstractmethod
    def empty(self) -> bool:
        raise NotImplementedError("Subclasses must implement `empty`")


class IdentityPacker:
    def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken

    def __call__(self, samples):
        return samples

    def get_token_num_to_request(self, cur_step, warmup):
        return (
            (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
            + self.bsz_warmup_init_mbtoken
            if warmup
            else self.token_micro_bsz
        )


class TextBatchingQueue(BaseBatchingQueue):
    """Batching text queue for text data."""

    def __init__(
        self,
        token_micro_bsz,
        buffer_size: int = 500,
        bsz_warmup_steps: int = -1,
        bsz_warmup_init_mbtoken: int = 200,
    ) -> None:
        super().__init__()
        self._step = 0
        self.token_micro_bsz = token_micro_bsz
        self.bsz_warmup_steps = bsz_warmup_steps
        self.buffer_size = buffer_size  # minimum samples in buffer
        self.buffer = DynamicBatchSizeBuffer()
        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken  # training warmup args
        assert self.bsz_warmup_init_mbtoken >= 0
        self.packer = IdentityPacker(
            token_micro_bsz=token_micro_bsz,
            bsz_warmup_steps=bsz_warmup_steps,
            bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
        )

    def is_full_filled(self) -> bool:
        return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz

    def put_item(self, item: dict[str, any]):
        if len(item["input_ids"]) == 1:
            print("WARNING: EMPTY STRING.")
            return

        self.buffer.append(item)

    def get_token_num_to_request(self):
        if self.packer is not None:
            warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
            return self.packer.get_token_num_to_request(self._step, warmup=warmup)
        else:
            return self.get_cur_token_micro_bsz()

    def get_cur_token_micro_bsz(self):
        warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
        if warmup:
            return (
                (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * self._step // self.bsz_warmup_steps
                + self.bsz_warmup_init_mbtoken
            )
        else:
            return self.token_micro_bsz

    def get_micro_batch(self, step) -> any:
        """Get a micro batch from the buffer according to the current step.

        Args:
            step: the current step.

        Returns:
            data: a list of samples.
        """
        self._step = step
        n_token_per_iter = self.get_token_num_to_request()
        cur_token_micro_bsz = self.get_cur_token_micro_bsz()
        assert cur_token_micro_bsz % n_token_per_iter == 0, (
            "The token num to get for each request should be divisible by token micro bsz."
        )
        n_iter = int(cur_token_micro_bsz // n_token_per_iter)
        data = []
        for _ in range(n_iter):
            samples = self.buffer.get_samples(n_token_per_iter)
            if self.packer:
                samples = self.packer(samples)  # maybe packed into one sample, but wrapped in list.

            data.extend(samples)

        self.buffer.flush()  # remove the selected samples.
        return data

    def empty(self) -> bool:
        return len(self.buffer) == 0
```
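A worked example of the warmup budget used by `get_cur_token_micro_bsz` and `IdentityPacker.get_token_num_to_request`: the per-step token budget ramps linearly from `bsz_warmup_init_mbtoken` to `token_micro_bsz` over `bsz_warmup_steps`, then stays flat. The numbers below are illustrative inputs only.

```python
# budget(step) = (token_micro_bsz - init) * step // warmup_steps + init
# while step <= warmup_steps, and token_micro_bsz afterwards.
token_micro_bsz, warmup_steps, init = 8192, 100, 200

def budget(step: int) -> int:
    warmup = step <= warmup_steps and warmup_steps > 0
    return (token_micro_bsz - init) * step // warmup_steps + init if warmup else token_micro_bsz

print(budget(0))    # 200  -> starts at bsz_warmup_init_mbtoken
print(budget(50))   # 4196 -> (8192 - 200) * 50 // 100 + 200
print(budget(100))  # 8192 -> ramp reaches the full micro batch size
print(budget(101))  # 8192 -> warmup over
```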
src/llamafactory/v1/utils/constants.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
src/llamafactory/v1/utils/dtype.py (new file, mode 100644)
```python
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by Bytedance's verl library.
# https://github.com/volcengine/verl/blob/v0.6.1/verl/utils/torch_dtypes.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device

from ..accelerator.interface import DistributedInterface


class DtypeRegistry:
    HALF_LIST = ["fp16", "float16", "half", torch.float16]
    FLOAT_LIST = ["fp32", "float32", "float", torch.float32]
    BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16]


class DtypeInterface:
    """Type of precision used."""

    _is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator)
    _is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator)
    _is_fp32_available = True

    @staticmethod
    def is_available(precision: str | torch.dtype) -> bool:
        if precision in DtypeRegistry.HALF_LIST:
            return DtypeInterface._is_fp16_available
        elif precision in DtypeRegistry.FLOAT_LIST:
            return DtypeInterface._is_fp32_available
        elif precision in DtypeRegistry.BFLOAT_LIST:
            return DtypeInterface._is_bf16_available
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def is_fp16(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.HALF_LIST

    @staticmethod
    def is_fp32(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.FLOAT_LIST

    @staticmethod
    def is_bf16(precision: str | torch.dtype) -> bool:
        return precision in DtypeRegistry.BFLOAT_LIST

    @staticmethod
    def to_dtype(precision: str | torch.dtype) -> torch.dtype:
        if precision in DtypeRegistry.HALF_LIST:
            return torch.float16
        elif precision in DtypeRegistry.FLOAT_LIST:
            return torch.float32
        elif precision in DtypeRegistry.BFLOAT_LIST:
            return torch.bfloat16
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @staticmethod
    def to_str(precision: torch.dtype) -> str:
        if precision == torch.float16:
            return "float16"
        elif precision == torch.float32:
            return "float32"
        elif precision == torch.bfloat16:
            return "bfloat16"
        else:
            raise RuntimeError(f"Unexpected precision: {precision}")

    @contextmanager
    def set_dtype(self, precision: str | torch.dtype):
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.to_dtype(precision))
        try:
            yield
        finally:
            torch.set_default_dtype(original_dtype)
```
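A minimal usage sketch for `DtypeInterface`, assuming a working `DistributedInterface` is importable (the availability probes run at class-definition time against `DistributedInterface.current_accelerator`).

```python
import torch

print(DtypeInterface.to_dtype("bf16"))       # torch.bfloat16 (aliases normalized)
print(DtypeInterface.to_str(torch.float16))  # "float16"
print(DtypeInterface.is_bf16("bfloat16"))    # True

# set_dtype temporarily swaps the torch default dtype and always restores it,
# even if the body raises.
iface = DtypeInterface()
with iface.set_dtype("fp16"):
    assert torch.get_default_dtype() == torch.float16
assert torch.get_default_dtype() == torch.float32
```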
src/llamafactory/v1/utils/env.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket


def find_available_port() -> int:
    """Find an available port on the local machine."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    """Check if the environment variable is enabled."""
    return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
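```

A short usage sketch. Note that `find_available_port` probes and then closes the socket, so another process can grab the port before it is reused; this is the usual trade-off of this pattern rather than a bug in the helper.

```python
import os

# Pick a free rendezvous port before spawning workers (illustrative use).
os.environ.setdefault("MASTER_PORT", str(find_available_port()))

os.environ["MY_FLAG"] = "Yes"          # hypothetical variable for illustration
print(is_env_enabled("MY_FLAG"))       # True: matching is case-insensitive
print(is_env_enabled("UNSET_FLAG"))    # False: default "0" is not truthy
```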
src/llamafactory/v1/utils/helper.py (new file, mode 100644)
```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import PreTrainedTokenizer
from transformers import set_seed as hf_set_seed

from ..accelerator.interface import DistributedInterface
from .constants import IGNORE_INDEX
from .types import BatchInput, ModelInput, Processor, Tensor


def set_seed(seed: int) -> None:
    """Set seed for reproducibility.

    Args:
        seed: Random seed.
    """
    hf_set_seed(seed)


def is_tokenizer(processor: Processor) -> bool:
    """Check if processor is tokenizer.

    Args:
        processor: Processor.

    Returns:
        Whether processor is tokenizer.
    """
    return not hasattr(processor, "tokenizer")


def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
    """Get tokenizer from processor.

    Args:
        processor: Processor.

    Returns:
        Tokenizer.
    """
    return processor.tokenizer if hasattr(processor, "tokenizer") else processor


def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor:
    if tensor.shape[-1] >= max_seqlen:
        return tensor[..., :max_seqlen]

    pad_shape = list(tensor.shape)
    pad_shape[-1] = max_seqlen - tensor.shape[-1]
    pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, pad_tensor], dim=-1)


def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]:
    max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen)
    padded_samples = []
    for sample in samples:
        padded_sample = {}
        for key, value in sample.items():
            if "label" in key:
                pad_value = IGNORE_INDEX
            else:
                pad_value = 0

            if not isinstance(value, str):
                padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value)
            else:
                padded_sample[key] = value

        padded_samples.append(padded_sample)

    return padded_samples


def compute_valid_tokens(batches: list[BatchInput]) -> int:
    """Compute valid tokens in batches.

    Args:
        batches: Batches.

    Returns:
        Number of valid tokens.
    """
    device = DistributedInterface().current_device
    return sum(
        (batch["labels"].to(device, non_blocking=True) != IGNORE_INDEX).sum().item()
        for batch in batches
        if "labels" in batch
    )
```
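A worked example for `pad_and_truncate`. It assumes `IGNORE_INDEX == -100` (the usual Hugging Face label-masking convention; the real value lives in `.constants`, whose body is not shown in this diff).

```python
samples = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6, 7, 8], "labels": [4, 5, 6, 7, 8]},
]
padded = pad_and_truncate(samples, max_seqlen=4)
# max_length = min(max(3, 5), 4) = 4, so every tensor ends up length 4:
print(padded[0]["input_ids"])  # tensor([1, 2, 3, 0])     -> padded with 0
print(padded[0]["labels"])     # tensor([1, 2, 3, -100])  -> padded with IGNORE_INDEX
print(padded[1]["input_ids"])  # tensor([4, 5, 6, 7])     -> truncated to max_seqlen
```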
src/llamafactory/v1/utils/logging.py (new file, mode 100644)
```python
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import threading
from functools import lru_cache
from typing import Optional


_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO


class _Logger(logging.Logger):
    """A logger that supports rank0 logging."""

    def info_rank0(self, *args, **kwargs) -> None:
        self.info(*args, **kwargs)

    def warning_rank0(self, *args, **kwargs) -> None:
        self.warning(*args, **kwargs)

    def warning_rank0_once(self, *args, **kwargs) -> None:
        self.warning(*args, **kwargs)


def _get_default_logging_level() -> "logging._Level":
    """Return the default logging level."""
    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
    if env_level_str:
        if env_level_str.upper() in logging._nameToLevel:
            return logging._nameToLevel[env_level_str.upper()]
        else:
            raise ValueError(f"Unknown logging level: {env_level_str}.")

    return _default_log_level


def _get_library_name() -> str:
    return __name__.split(".")[0]


def _get_library_root_logger() -> "_Logger":
    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
    """Configure root logger using a stdout stream handler with an explicit format."""
    global _default_handler

    with _thread_lock:
        if _default_handler:  # already configured
            return

        formatter = logging.Formatter(
            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        _default_handler = logging.StreamHandler(sys.stdout)
        _default_handler.setFormatter(formatter)
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_get_default_logging_level())
        library_root_logger.propagate = False


def get_logger(name: str | None = None) -> "_Logger":
    """Return a logger with the specified name. It is not supposed to be accessed externally."""
    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()
    return logging.getLogger(name)


def add_handler(handler: "logging.Handler") -> None:
    """Add a handler to the root logger."""
    _configure_library_root_logger()
    _get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
    """Remove a handler from the root logger."""
    _configure_library_root_logger()
    _get_library_root_logger().removeHandler(handler)


def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.info(*args, **kwargs)


def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


@lru_cache(None)
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_rank0_once = warning_rank0_once
```
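A short usage sketch for the module above. The `*_rank0` helpers are monkey-patched onto `logging.Logger` at import time and gate on the `LOCAL_RANK` environment variable; `warning_rank0_once` additionally deduplicates via `lru_cache`, so the positional arguments must be hashable.

```python
logger = get_logger(__name__)
logger.info("visible on every rank")
logger.info_rank0("visible only when LOCAL_RANK is 0 or unset")
logger.warning_rank0_once("emitted at most once per unique message, rank 0 only")
```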
src/llamafactory/v1/utils/objects.py (new file, mode 100644)
```python
# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .types import ModelInput


class StatefulBuffer:
    """A buffer that stores model inputs."""

    def __init__(self, max_buffer_size: int = 1_000_000_000) -> None:
        self._buffer: list[ModelInput] = []
        self._buffer_size: int = 0
        self._max_buffer_size: int = max_buffer_size

    def __len__(self) -> int:
        return len(self._buffer)

    @property
    def size(self) -> int:
        return self._buffer_size

    def put(self, samples: list[ModelInput]) -> None:
        """Add samples to the buffer."""
        num_tokens = sum(len(sample["input_ids"]) for sample in samples)
        if self._buffer_size + num_tokens > self._max_buffer_size:
            raise ValueError(f"Buffer size exceeds max buffer size {self._max_buffer_size}.")

        self._buffer.extend(samples)
        self._buffer_size += num_tokens

    def get(self, value: int) -> list[ModelInput]:
        """Get samples from the buffer and remove them."""
        samples = self._buffer[:value]
        self._buffer_size -= sum(len(sample["input_ids"]) for sample in samples)
        del self._buffer[:value]
        return samples

    def clear(self) -> None:
        """Clear the buffer."""
        self._buffer = []
        self._buffer_size = 0

    def state_dict(self) -> dict:
        """Returns the state of the buffer."""
        return {
            "buffer": self._buffer,
            "buffer_size": self._buffer_size,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Loads the state into the buffer."""
        self._buffer = state_dict["buffer"]
        self._buffer_size = state_dict["buffer_size"]
```
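A usage sketch for `StatefulBuffer`. The key detail is that `len(buf)` counts samples while `buf.size` counts tokens (`input_ids` lengths), and `get(n)` pops the first `n` samples.

```python
buf = StatefulBuffer(max_buffer_size=10)
buf.put([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
print(len(buf), buf.size)  # 2 samples, 5 tokens

first = buf.get(1)         # pops the first sample (3 tokens)
print(len(buf), buf.size)  # 1 sample, 2 tokens

state = buf.state_dict()   # round-trips through load_state_dict for checkpointing
buf.clear()
buf.load_state_dict(state)
print(len(buf), buf.size)  # 1 sample, 2 tokens again
```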