OpenDAS / LLaMA-Factory

Commit ca625f43, authored Mar 30, 2026 by shihm

    uodata

parent 7164651d

Changes: 327. This page shows 20 changed files with 1466 additions and 0 deletions (+1466, -0).
src/llamafactory/v1/accelerator/__init__.py            +0    -0
src/llamafactory/v1/accelerator/helper.py              +215  -0
src/llamafactory/v1/accelerator/interface.py           +249  -0
src/llamafactory/v1/accelerator/profiler.py            +0    -0
src/llamafactory/v1/config/__init__.py                 +0    -0
src/llamafactory/v1/config/arg_parser.py               +80   -0
src/llamafactory/v1/config/arg_utils.py                +95   -0
src/llamafactory/v1/config/data_args.py                +28   -0
src/llamafactory/v1/config/model_args.py               +54   -0
src/llamafactory/v1/config/sample_args.py              +30   -0
src/llamafactory/v1/config/training_args.py            +50   -0
src/llamafactory/v1/core/__init__.py                   +0    -0
src/llamafactory/v1/core/base_sampler.py               +67   -0
src/llamafactory/v1/core/base_trainer.py               +58   -0
src/llamafactory/v1/core/chat_sampler.py               +44   -0
src/llamafactory/v1/core/data_engine.py                +187  -0
src/llamafactory/v1/core/model_engine.py               +174  -0
src/llamafactory/v1/core/model_loader.py               +135  -0
src/llamafactory/v1/core/trainer_utils/__init__.py     +0    -0
src/llamafactory/v1/core/trainer_utils/callback.py     +0    -0
src/llamafactory/v1/accelerator/__init__.py  (new file, mode 0 → 100644; empty)

src/llamafactory/v1/accelerator/helper.py  (new file, mode 0 → 100644)
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/utils/dist_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions used by the distributed interface.
Including:
- Environment info (rank, world_size, local_rank, etc.)
- Accelerator info (device type, device count, etc.)
- Collective communication operations (all_gather, all_reduce, broadcast)
- Synchronize processes and ensure main-process-first execution order
"""

import os
from collections.abc import Callable
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache, wraps
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist

from ..utils.types import ProcessGroup, Tensor, TensorLike


@unique
class DeviceType(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    META = "meta"
    MPS = "mps"
    NPU = "npu"
    XPU = "xpu"


@unique
class ReduceOp(str, Enum):
    SUM = "sum"
    MEAN = "mean"
    MAX = "max"
    MIN = "min"


def requires_accelerator(fn):
    """Decorator to check if torch.accelerator is available.

    Note: this API requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not hasattr(torch, "accelerator"):
            raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")

        return fn(*args, **kwargs)

    return wrapper


def is_distributed() -> bool:
    """Check if distributed environment is available."""
    return os.getenv("RANK") is not None


def get_rank() -> int:
    """Get rank."""
    return int(os.getenv("RANK", "0"))


def get_world_size() -> int:
    """Get world size."""
    return int(os.getenv("WORLD_SIZE", "1"))


def get_local_rank() -> int:
    """Get local rank."""
    return int(os.getenv("LOCAL_RANK", "0"))


def get_local_world_size() -> int:
    """Get local world size."""
    return int(os.getenv("LOCAL_WORLD_SIZE", "1"))


@lru_cache
@requires_accelerator
def get_current_accelerator(check_available: bool = True) -> torch.device:
    """Get current accelerator."""
    accelerator = torch.accelerator.current_accelerator(check_available=check_available)
    return accelerator or torch.device(DeviceType.CPU.value)


@lru_cache
@requires_accelerator
def get_device_count() -> int:
    """Get the number of available devices."""
    return torch.accelerator.device_count()


@requires_accelerator
def synchronize() -> None:
    """Synchronize all processes."""
    torch.accelerator.synchronize()


@requires_accelerator
def set_device() -> None:
    """Set current accelerator."""
    torch.accelerator.set_device_index(get_local_rank())


def is_torch_cuda_available():
    """Check if CUDA is available."""
    return get_current_accelerator().type == DeviceType.CUDA


def is_torch_mps_available():
    """Check if MPS is available."""
    return get_current_accelerator().type == DeviceType.MPS


def is_torch_npu_available():
    """Check if NPU is available."""
    return get_current_accelerator().type == DeviceType.NPU


def is_torch_xpu_available():
    """Check if XPU is available."""
    return get_current_accelerator().type == DeviceType.XPU


def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
    """Operate tensorlike data on current accelerator."""
    device = get_current_accelerator()
    is_tensor = isinstance(data, torch.Tensor)
    is_ndarray = isinstance(data, np.ndarray)
    if is_tensor:
        orig_device = data.device
        data = data.to(device=device)
    elif is_ndarray:
        data = torch.from_numpy(data).to(device=device, dtype=torch.float)
    else:
        data = torch.tensor(data, dtype=torch.float, device=device)

    result = fn(data, **kwargs)
    if is_tensor:
        return result.to(orig_device)
    elif is_ndarray:
        return result.cpu().numpy()
    elif result.numel() == 1:
        return result.item()
    else:
        return result.tolist()


def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
    """Gathers the tensor from all ranks and stacks them at the first dim."""
    world_size = get_world_size()
    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
    dist.all_gather_into_tensor(output_tensor, tensor, group=group)
    return output_tensor.view(-1, *tensor.size())


def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
    """Performs all reduce in the given process group."""
    reduce_ops = {
        ReduceOp.MEAN: dist.ReduceOp.SUM,
        ReduceOp.SUM: dist.ReduceOp.SUM,
        ReduceOp.MAX: dist.ReduceOp.MAX,
        ReduceOp.MIN: dist.ReduceOp.MIN,
    }
    dist.all_reduce(tensor, op=reduce_ops[op], group=group)
    if op == ReduceOp.MEAN:
        # ReduceOp.AVG is not supported by the NPU backend
        tensor /= dist.get_world_size(group=group)

    return tensor


def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
    """Broadcasts the tensor from the src process to all other processes."""
    dist.broadcast(tensor, src=src, group=group)
    return tensor


@contextmanager
def main_process_first(local_only: bool = True) -> None:
    """A context manager for torch distributed environment to do something on the main process first."""
    if get_world_size() > 1:
        is_main_process = get_local_rank() == 0 if local_only else get_rank() == 0
        try:
            if not is_main_process:
                dist.barrier()

            yield
        finally:
            if is_main_process:
                dist.barrier()
    else:
        yield
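
A usage sketch for these helpers (not part of the diff; assumes a torchrun-style launch that sets RANK, WORLD_SIZE, and LOCAL_RANK, and torch>=2.7 for torch.accelerator):

import torch
import torch.distributed as dist

from llamafactory.v1.accelerator import helper

if helper.is_distributed():
    helper.set_device()        # bind this process to its local accelerator
    dist.init_process_group()  # the collectives below need an initialized group

    with helper.main_process_first():  # rank 0 enters first, other ranks wait at a barrier
        ...                            # e.g. download or preprocess a dataset once

    loss = torch.tensor(1.0, device=helper.get_current_accelerator())
    loss = helper.all_reduce(loss, op=helper.ReduceOp.MEAN)  # averaged over all ranks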

src/llamafactory/v1/accelerator/interface.py  (new file, mode 0 → 100644)
# Copyright 2025 Bytedance Ltd. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's VeOmni library.
# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/distributed/parallel_state.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A unified interface for model parallelism and data parallelism.
Supports model parallelism types:
- mp_replicate: Replicate model across multiple devices.
- mp_shard: Shard model across multiple devices.
And data parallelism types:
- dp: Data parallelism.
- cp: Context parallelism.
"""

from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import Any, Optional

from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper


class Dim(str, Enum):
    """Dimension names."""

    MP_REPLICATE = "mp_replicate"
    MP_SHARD = "mp_shard"
    DP = "dp"
    CP = "cp"


@dataclass
class DistributedStrategy:
    """Distributed strategy."""

    mp_replicate_size: int = 1
    """Model parallel replicate size, default to 1."""
    mp_shard_size: int | None = None
    """Model parallel shard size, default to world_size // mp_replicate_size."""
    dp_size: int | None = None
    """Data parallel size, default to world_size // cp_size."""
    cp_size: int = 1
    """Context parallel size, default to 1."""

    def __post_init__(self) -> None:
        if not helper.is_distributed():
            self.mp_shard_size = 1
        elif self.mp_shard_size is None:
            self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
        elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
            raise ValueError(
                f"mp_replicate_size * mp_shard_size must equal to world_size, "
                f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
            )

        if not helper.is_distributed():
            self.dp_size = 1
        elif self.dp_size is None:
            self.dp_size = helper.get_world_size() // self.cp_size
        elif self.dp_size * self.cp_size != helper.get_world_size():
            raise ValueError(
                f"dp_size * cp_size must equal to world_size, "
                f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
            )

    @property
    def model_mesh_shape(self) -> tuple[int, int]:
        """Model parallel mesh shape."""
        return (self.mp_replicate_size, self.mp_shard_size)

    @property
    def model_mesh_dim_names(self) -> tuple[str, str]:
        """Model parallel mesh dimension names."""
        return (Dim.MP_REPLICATE.value, Dim.MP_SHARD.value)

    @property
    def data_mesh_shape(self) -> tuple[int, int]:
        """Data parallel mesh shape."""
        return (self.dp_size, self.cp_size)

    @property
    def data_mesh_dim_names(self) -> tuple[str, str]:
        """Data parallel mesh dimension names."""
        return (Dim.DP.value, Dim.CP.value)


class DistributedInterface:
    """Distributed interface."""

    _instance: Optional["DistributedInterface"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
        """Singleton pattern."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)

        return cls._instance

    def __init__(self, config: DistributedConfig | None = None) -> None:
        if self._initialized:
            return

        self._is_distributed = helper.is_distributed()
        self._rank = helper.get_rank()
        self._world_size = helper.get_world_size()
        self._local_rank = helper.get_local_rank()
        self._local_world_size = helper.get_local_world_size()
        self.current_accelerator = helper.get_current_accelerator()
        self.device_count = helper.get_device_count()
        if config is None:
            self.strategy = DistributedStrategy()
            timeout = 18000
        else:
            self.strategy = DistributedStrategy(
                mp_replicate_size=config.get("mp_replicate_size", 1),
                mp_shard_size=config.get("mp_shard_size", None),
                dp_size=config.get("dp_size", None),
                cp_size=config.get("cp_size", 1),
            )
            timeout = config.get("timeout", 18000)

        if self._is_distributed:
            helper.set_device()
            init_process_group(timeout=timedelta(seconds=timeout))
            self.model_device_mesh = init_device_mesh(
                device_type=self.current_accelerator.type,
                mesh_shape=self.strategy.model_mesh_shape,
                mesh_dim_names=self.strategy.model_mesh_dim_names,
            )
            self.data_device_mesh = init_device_mesh(
                device_type=self.current_accelerator.type,
                mesh_shape=self.strategy.data_mesh_shape,
                mesh_dim_names=self.strategy.data_mesh_dim_names,
            )
        else:
            self.model_device_mesh = None
            self.data_device_mesh = None

        self._initialized = True

    def __str__(self) -> str:
        return (
            f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
            f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
            f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
        )

    def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
        """Get device mesh for specified dimension."""
        if dim is None:
            raise ValueError("dim must be specified.")
        elif self.model_device_mesh is None:
            return None
        elif dim in self.strategy.data_mesh_dim_names:
            return self.data_device_mesh[dim.value]
        else:
            return self.model_device_mesh[dim.value]

    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
        """Get process group for specified dimension."""
        if self.model_device_mesh is None or dim is None:
            return None
        else:
            return self.get_device_mesh(dim).get_group()

    def get_rank(self, dim: Dim | None = None) -> int:
        """Get parallel rank for specified dimension."""
        if self.model_device_mesh is None:
            return 0
        elif dim is None:
            return self._rank
        else:
            return self.get_device_mesh(dim).get_local_rank()

    def get_world_size(self, dim: Dim | None = None) -> int:
        """Get parallel size for specified dimension."""
        if self.model_device_mesh is None:
            return 1
        elif dim is None:
            return self._world_size
        else:
            return self.get_device_mesh(dim).size()

    def get_local_rank(self) -> int:
        """Get parallel local rank."""
        return self._local_rank

    def get_local_world_size(self) -> int:
        """Get parallel local world size."""
        return self._local_world_size

    def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
        """Gather tensor across specified parallel group."""
        if self.model_device_mesh is not None:
            return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
        else:
            return data

    def all_reduce(
        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
    ) -> TensorLike:
        """Reduce tensor across specified parallel group."""
        if self.model_device_mesh is not None:
            return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
        else:
            return data

    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
        """Broadcast tensor across specified parallel group."""
        if self.model_device_mesh is not None:
            return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
        else:
            return data

    def sync(self) -> None:
        """Synchronize all processes."""
        helper.synchronize()

    def barrier(self) -> None:
        """Barrier all processes."""
        barrier()

    def destroy(self) -> None:
        """Destroy all processes."""
        destroy_process_group()


if __name__ == "__main__":
    print(DistributedInterface(DistributedStrategy()))
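
A sketch of how the two meshes factorize the world size (not part of the diff; assumes 8 processes launched with torchrun and that DistributedConfig is a plain dict):

from llamafactory.v1.accelerator.interface import Dim, DistributedInterface

# torchrun --nproc-per-node 8 this_script.py
interface = DistributedInterface(
    {"mp_replicate_size": 2, "mp_shard_size": 4, "dp_size": 4, "cp_size": 2}
)
# model mesh: (2, 4), data mesh: (4, 2); both products must equal world_size == 8
dp_group = interface.get_group(Dim.DP)            # process group of the 4-way dp axis
mean_loss = interface.all_reduce(0.5, dim=Dim.DP)  # scalar in, scalar out (mean over dp ranks)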

src/llamafactory/v1/accelerator/profiler.py  (new file, mode 0 → 100644; empty)

src/llamafactory/v1/config/__init__.py  (new file, mode 0 → 100644; empty)

src/llamafactory/v1/config/arg_parser.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
from pathlib import Path
from typing import Any

from omegaconf import OmegaConf
from transformers import HfArgumentParser

from ...extras.misc import is_env_enabled
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


InputArgument = dict[str, Any] | list[str] | None


def validate_args(
    data_args: DataArguments,
    model_args: ModelArguments,
    training_args: TrainingArguments,
    sample_args: SampleArguments,
):
    """Validate arguments."""
    if (
        model_args.quant_config is not None
        and training_args.dist_config is not None
        and training_args.dist_config.name == "deepspeed"
    ):
        raise ValueError("Quantization is not supported with deepspeed backend.")


def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
    """Parse arguments from command line or config file."""
    parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_KEYS")
    if args is None:
        if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
            override_config = OmegaConf.from_cli(sys.argv[2:])
            dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
            args = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
        else:
            # list of strings
            args = sys.argv[1:]

    if isinstance(args, dict):
        (*parsed_args,) = parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
    else:
        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args, return_remaining_strings=True)
        if unknown_args and not allow_extra_keys:
            print(parser.format_help())
            print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

    validate_args(*parsed_args)
    return tuple(parsed_args)


if __name__ == "__main__":
    print(get_args())
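
A usage sketch for get_args (not part of the diff). With a YAML or JSON path as the first CLI argument, the remaining key=value pairs override keys from the file via OmegaConf.merge; a plain dict is also accepted:

# CLI form, e.g.:
#   python -m llamafactory.v1.config.arg_parser train.yaml learning_rate=5e-5
from llamafactory.v1.config.arg_parser import get_args

data_args, model_args, training_args, sample_args = get_args(
    {"model": "llamafactory/tiny-random-qwen2.5", "dataset": "data/v1_sft_demo.yaml"}
)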

src/llamafactory/v1/config/arg_utils.py  (new file, mode 0 → 100644)
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/training_args.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from enum import Enum, unique


class PluginConfig(dict):
    """Dictionary that allows attribute access."""

    @property
    def name(self) -> str:
        """Plugin name."""
        if "name" not in self:
            raise ValueError("Plugin configuration must have a 'name' field.")

        return self["name"]


PluginArgument = PluginConfig | dict | str | None


@unique
class ModelClass(str, Enum):
    """Auto class for model config."""

    LLM = "llm"
    CLS = "cls"
    OTHER = "other"


@unique
class SampleBackend(str, Enum):
    HF = "hf"
    VLLM = "vllm"


def _convert_str_dict(data: dict) -> dict:
    """Parse string representation inside the dictionary.

    Args:
        data: The string or dictionary to convert.

    Returns:
        The converted dictionary.
    """
    for key, value in data.items():
        if isinstance(value, dict):
            data[key] = _convert_str_dict(value)
        elif isinstance(value, str):
            if value.lower() in ("true", "false"):
                data[key] = value.lower() == "true"
            elif value.isdigit():
                data[key] = int(value)
            elif value.replace(".", "", 1).isdigit():
                data[key] = float(value)

    return data


def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
    """Get the plugin configuration from the argument value.

    Args:
        config: The argument value to get the plugin configuration from.

    Returns:
        The plugin configuration.
    """
    if config is None:
        return None

    if isinstance(config, str) and config.startswith("{"):
        config = json.loads(config)
        config = _convert_str_dict(config)

    if "name" not in config:
        raise ValueError("Plugin configuration must have a 'name' field.")

    return PluginConfig(config)
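
A sketch of the string-to-config path in get_plugin_config (not part of the diff; the plugin fields other than "name" are hypothetical):

from llamafactory.v1.config.arg_utils import get_plugin_config

cfg = get_plugin_config('{"name": "lora", "r": "16", "use_dora": "false"}')
assert cfg.name == "lora"
assert cfg["r"] == 16            # "16" coerced to int by _convert_str_dict
assert cfg["use_dora"] is False  # "false" coerced to bool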

src/llamafactory/v1/config/data_args.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field


@dataclass
class DataArguments:
    dataset: str | None = field(
        default=None,
        metadata={"help": "Path to the dataset."},
    )
    cutoff_len: int = field(
        default=2048,
        metadata={"help": "Cutoff length for the dataset."},
    )

src/llamafactory/v1/config/model_args.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from .arg_utils import ModelClass, PluginConfig, get_plugin_config


@dataclass
class ModelArguments:
    model: str = field(
        metadata={"help": "Path to the model or model identifier from Hugging Face."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
    )
    use_fast_processor: bool = field(
        default=True,
        metadata={"help": "Use fast processor from Hugging Face."},
    )
    model_class: ModelClass = field(
        default=ModelClass.LLM,
        metadata={"help": "Model class from Hugging Face."},
    )
    peft_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "PEFT configuration for the model."},
    )
    kernel_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Kernel configuration for the model."},
    )
    quant_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Quantization configuration for the model."},
    )

    def __post_init__(self) -> None:
        self.peft_config = get_plugin_config(self.peft_config)
        self.kernel_config = get_plugin_config(self.kernel_config)
        self.quant_config = get_plugin_config(self.quant_config)

src/llamafactory/v1/config/sample_args.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from .arg_utils import SampleBackend


@dataclass
class SampleArguments:
    sample_backend: SampleBackend = field(
        default=SampleBackend.HF,
        metadata={"help": "Sampling backend, default to 'hf'."},
    )
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "Maximum number of new tokens to generate."},
    )

src/llamafactory/v1/config/training_args.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from uuid import uuid4

from .arg_utils import PluginConfig, get_plugin_config


@dataclass
class TrainingArguments:
    output_dir: str = field(
        default=os.path.join("outputs", str(uuid4())),
        metadata={"help": "Path to the output directory."},
    )
    micro_batch_size: int = field(
        default=1,
        metadata={"help": "Micro batch size for training."},
    )
    global_batch_size: int = field(
        default=1,
        metadata={"help": "Global batch size for training."},
    )
    learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for training."},
    )
    bf16: bool = field(
        default=False,
        metadata={"help": "Use bf16 for training."},
    )
    dist_config: PluginConfig | None = field(
        default=None,
        metadata={"help": "Distribution configuration for training."},
    )

    def __post_init__(self) -> None:
        self.dist_config = get_plugin_config(self.dist_config)

src/llamafactory/v1/core/__init__.py  (new file, mode 0 → 100644; empty)

src/llamafactory/v1/core/base_sampler.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import AsyncGenerator

from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.inference_engine import HuggingFaceEngine
from .utils.rendering import Renderer


class BaseSampler:
    """Base sampler.

    Args:
        args: Sample arguments.
        model_args: Model arguments.
        model: Model.
        renderer: Renderer.
    """

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        if args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(args, model_args, model, renderer)
        else:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        async for token in self.engine.generate(messages, tools):
            yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        return await self.engine.batch_infer(dataset)
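
A usage sketch for the async streaming API (not part of the diff; assumes OpenAI-style message dicts for the Message type and an already-constructed sampler):

import asyncio


async def chat(sampler) -> None:  # sampler: a constructed BaseSampler
    messages = [{"role": "user", "content": "Hello!"}]
    async for token in sampler.generate(messages):
        print(token, end="", flush=True)

# asyncio.run(chat(sampler))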

src/llamafactory/v1/core/base_trainer.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of trainer.
Init Phase:
1. Init dataloader.
2. Init optimizer (deepspeed).
3. Shard model.
4. Init optimizer (fsdp).
5. Init scheduler.
Train Phase:
1. Train Loop
"""

from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator


class BaseTrainer:
    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        processor: Processor,
        dataset: TorchDataset,
    ) -> None:
        self.args = args
        self.model = model
        self.processor = processor
        self.dataset = dataset
        self.data_collator = DataCollator()
        self.optimizer = None
        self.lr_scheduler = None

    def init_model_and_optimizer(self) -> None:
        pass

    def create_dataloader(self) -> None:
        pass

    def fit(self) -> None:
        pass
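
A sketch of how a concrete subclass might fill in these stubs, following the phase order in the module docstring (not part of the diff; the DataLoader wiring and the HF-style .loss output are assumptions):

import torch
from torch.utils.data import DataLoader

from llamafactory.v1.core.base_trainer import BaseTrainer


class SketchTrainer(BaseTrainer):
    def create_dataloader(self) -> None:
        # init dataloader, reusing the collator built in BaseTrainer.__init__
        self.dataloader = DataLoader(
            self.dataset, batch_size=self.args.micro_batch_size, collate_fn=self.data_collator
        )

    def init_model_and_optimizer(self) -> None:
        # shard the model (omitted here) and build the optimizer / scheduler
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)

    def fit(self) -> None:
        self.create_dataloader()
        self.init_model_and_optimizer()
        for batch in self.dataloader:
            loss = self.model(**batch).loss  # assumes an HF-style output with .loss
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()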

src/llamafactory/v1/core/chat_sampler.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from ..config.sample_args import SampleArguments, SampleBackend
from .model_loader import ModelLoader


class BaseEngine(ABC):
    @abstractmethod
    def __init__(self, sample_args: SampleArguments, model_loader: ModelLoader) -> None: ...

    @abstractmethod
    async def generate(self):
        pass

    @abstractmethod
    async def batch_infer(self):
        pass


class HuggingFaceEngine(BaseEngine):
    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        self.args = sample_args


class ChatSampler:
    def __init__(self, model_loader: ModelLoader, sample_args: SampleArguments) -> None:
        if sample_args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(model_loader, sample_args)
        else:
            raise ValueError(f"Unknown sample backend: {sample_args.sample_backend}")

src/llamafactory/v1/core/data_engine.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of data engine.
Init Data engine:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
"""

import os
from collections.abc import Iterable
from typing import Any

from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from torch.utils.data import Dataset

from ..config.data_args import DataArguments
from ..utils.types import DatasetInfo, HFDataset, Sample


class DataEngine(Dataset):
    """Data engine.

    Args:
        data_args: Data arguments.
    """

    def __init__(self, data_args: DataArguments) -> None:
        self.args = data_args
        """Data arguments."""
        self.datasets: dict[str, HFDataset] = {}
        """Dict of (dataset_name, dataset)"""
        self.dataset_infos: dict[str, DatasetInfo] = {}
        """Dict of (dataset_name, dataset_info)"""
        self.data_index: list[tuple[str, int]] = []
        """List of (dataset_name, sample_index)"""
        self.streaming: bool = False
        """Whether dataset is streaming."""
        self._get_dataset_info()
        self._load_dataset()
        self._build_data_index()

    def _get_dataset_info(self) -> None:
        """Get dataset info from data arguments."""
        if self.args.dataset.endswith(".yaml") and os.path.isfile(self.args.dataset):
            # local file
            self.dataset_infos = OmegaConf.load(self.args.dataset)
        elif self.args.dataset.endswith(".yaml"):
            # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
            repo_id, filename = os.path.split(self.args.dataset)
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            self.dataset_infos = OmegaConf.load(filepath)
        elif os.path.exists(self.args.dataset):
            # local file(s)
            self.dataset_infos = {"default": {"path": self.args.dataset, "source": "local"}}
        else:
            # hf hub dataset, e.g. llamafactory/v1-sft-demo
            self.dataset_infos = {"default": {"path": self.args.dataset}}

    def _load_dataset(self) -> None:
        """Load datasets according to dataset info."""
        for dataset_name, dataset_info in self.dataset_infos.items():
            split = dataset_info.get("split", "train")
            streaming = dataset_info.get("streaming", False)
            self.streaming |= streaming
            if dataset_info.get("source", "hf_hub") == "hf_hub":
                from datasets import load_dataset

                self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
            else:
                # data loader plugin
                from ..plugins.data_plugins.loader import DataLoaderPlugin

                self.datasets[dataset_name] = DataLoaderPlugin(dataset_info["source"]).load(dataset_info)

    def _build_data_index(self) -> None:
        """Build dataset index."""
        for dataset_name, dataset in self.datasets.items():
            streaming = self.dataset_infos[dataset_name].get("streaming", False)
            if streaming:
                data_index = [(dataset_name, -1) for _ in range(1000)]
            else:
                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]

            size = self.dataset_infos[dataset_name].get("size")
            weight = self.dataset_infos[dataset_name].get("weight")
            if size or weight:
                # data index plugin
                from ..plugins.data_plugins.loader import DataIndexPlugin

                data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)

            self.data_index.extend(data_index)

    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
        """Convert dataset sample.

        Args:
            raw_sample (dict[str, Any]): Raw dataset sample.
            dataset_name (str): Dataset name.

        Returns:
            Sample: Dataset sample.
        """
        converter = self.dataset_infos[dataset_name].get("converter")
        if converter is not None:
            from ..plugins.data_plugins.converter import DataConverterPlugin

            return {"_dataset_name": dataset_name, **DataConverterPlugin(converter)(raw_sample)}
        else:
            return {"_dataset_name": dataset_name, **raw_sample}

    def __len__(self) -> int:
        """Get dataset length.

        Returns:
            int: Dataset length.
        """
        if self.streaming:
            return -1
        else:
            return len(self.data_index)

    def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
        """Get dataset item.

        Args:
            index (int): Dataset index.

        Returns:
            Sample: Dataset item.
        """
        if self.streaming:
            raise ValueError("Streaming dataset does not support index access.")

        if isinstance(index, int):
            dataset_name, sample_index = self.data_index[index]
            return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
        else:
            # data selector plugin
            from ..plugins.data_plugins.loader import DataSelectorPlugin

            selected_index = DataSelectorPlugin().select(self.data_index, index)
            if isinstance(selected_index, list):
                return [
                    self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
                    for dataset_name, sample_index in selected_index
                ]
            else:
                dataset_name, sample_index = selected_index
                return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

    def __iter__(self) -> Iterable[Sample]:
        """Get dataset iterator.

        Returns:
            Iterable[Sample]: Dataset iterator.
        """
        # NOTE: hf iterable dataset uses worker ids while map dataset does not
        # NOTE: add worker id and shuffle to the map dataset
        # https://github.com/huggingface/datasets/blob/4.0.0/src/datasets/iterable_dataset.py#L2214
        raise NotImplementedError()


if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
    python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
    """
    from ..config.arg_parser import get_args

    data_args, *_ = get_args()
    data_engine = DataEngine(data_args=data_args)
    print(data_engine[0])
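
A usage sketch for DataEngine (not part of the diff). The YAML keys shown are the ones read above (path, source, split, streaming, converter, size, weight); the converter name is hypothetical:

from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine

# data/v1_sft_demo.yaml (sketch):
#   demo:
#     path: llamafactory/v1-sft-demo   # hf hub dataset
#     split: train
#     converter: alpaca                # hypothetical converter name
#     weight: 1.0
engine = DataEngine(DataArguments(dataset="data/v1_sft_demo.yaml"))
print(len(engine))  # number of entries in the data index
print(engine[0])    # converted sample, carries a "_dataset_name" key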

src/llamafactory/v1/core/model_engine.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model engine.
How to use:
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init workflow:
1. Init processor.
2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
"""
import
torch
from
accelerate
import
init_empty_weights
from
transformers
import
AutoConfig
,
AutoProcessor
from
..accelerator.helper
import
DeviceType
from
..accelerator.interface
import
DistributedInterface
from
..config.model_args
import
ModelArguments
,
ModelClass
from
..utils
import
logging
from
..utils.types
import
HFConfig
,
HFModel
,
Processor
from
.utils.rendering
import
Renderer
logger
=
logging
.
get_logger
(
__name__
)
class
ModelEngine
:
"""Model engine.
Args:
model_args: Model arguments.
is_train: Whether to train the model.
"""
def
__init__
(
self
,
model_args
:
ModelArguments
,
is_train
:
bool
=
False
)
->
None
:
self
.
args
=
model_args
"""Model arguments."""
self
.
is_train
=
is_train
"""Whether to train the model."""
self
.
processor
=
self
.
_init_processor
()
"""Tokenizer or multi-modal processor."""
self
.
renderer
=
Renderer
(
self
.
args
.
template
,
self
.
processor
)
"""Renderer."""
self
.
model_config
=
self
.
_init_model_config
()
"""Model configuration."""
self
.
model
=
self
.
_init_model
()
"""HF model."""
def
_init_processor
(
self
)
->
Processor
:
"""Init processor.
NOTE: Transformers v5 always use fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return
AutoProcessor
.
from_pretrained
(
self
.
args
.
model
,
trust_remote_code
=
self
.
args
.
trust_remote_code
,
)
def
_init_model_config
(
self
)
->
HFConfig
:
"""Init model config."""
return
AutoConfig
.
from_pretrained
(
self
.
args
.
model
,
trust_remote_code
=
self
.
args
.
trust_remote_code
,
)
def
_init_model
(
self
)
->
HFModel
:
"""Init model.
Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if
self
.
args
.
init_config
is
not
None
:
from
..plugins.model_plugins.initialization
import
InitPlugin
init_device
=
InitPlugin
(
self
.
args
.
init_config
.
name
)()
else
:
init_device
=
DistributedInterface
().
current_device
init_kwargs
=
{
"device_map"
:
init_device
}
if
self
.
args
.
quant_config
is
not
None
:
from
..plugins.model_plugins.quantization
import
QuantizationPlugin
init_kwargs
=
QuantizationPlugin
(
self
.
args
.
quant_config
.
name
)(
init_kwargs
=
init_kwargs
,
config
=
self
.
model_config
,
tokenizer
=
self
.
processor
,
model_args
=
self
.
args
,
is_trainable
=
self
.
is_train
,
)
if
self
.
args
.
model_class
==
ModelClass
.
LLM
:
from
transformers
import
AutoModelForCausalLM
,
AutoModelForImageTextToText
if
type
(
self
.
model_config
)
in
AutoModelForImageTextToText
.
_model_mapping
.
keys
():
AutoClass
=
AutoModelForImageTextToText
else
:
AutoClass
=
AutoModelForCausalLM
elif
self
.
args
.
model_class
==
ModelClass
.
CLS
:
from
transformers
import
AutoModelForTokenClassification
AutoClass
=
AutoModelForTokenClassification
else
:
from
transformers
import
AutoModel
AutoClass
=
AutoModel
if
init_device
.
type
==
DeviceType
.
META
:
assert
self
.
args
.
quant_config
is
None
,
"Quantization is not supported with meta device."
with
init_empty_weights
():
model
=
AutoClass
.
from_config
(
self
.
model_config
)
else
:
model
=
AutoClass
.
from_pretrained
(
self
.
args
.
model
,
config
=
self
.
model_config
,
dtype
=
"auto"
,
trust_remote_code
=
self
.
args
.
trust_remote_code
,
**
init_kwargs
,
)
if
self
.
args
.
peft_config
is
None
:
if
self
.
is_train
:
logger
.
info_rank0
(
"Fine-tuning mode: full tuning"
)
model
=
model
.
to
(
torch
.
float32
)
else
:
logger
.
info_rank0
(
"Inference the original model"
)
else
:
from
..plugins.model_plugins.peft
import
PeftPlugin
model
=
PeftPlugin
(
self
.
args
.
peft_config
.
name
)(
model
,
self
.
args
.
peft_config
,
self
.
is_train
)
if
self
.
args
.
kernel_config
is
not
None
:
from
..plugins.model_plugins.kernels.interface
import
KernelPlugin
model
=
KernelPlugin
(
self
.
args
.
kernel_config
.
name
)(
model
,
include_kernels
=
self
.
args
.
kernel_config
.
get
(
"include_kernels"
)
)
return
model
if
__name__
==
"__main__"
:
"""
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from
..config.arg_parser
import
get_args
model_args
,
*
_
=
get_args
()
model_engine
=
ModelEngine
(
model_args
=
model_args
)
print
(
model_engine
.
processor
)
print
(
model_engine
.
model_config
)
print
(
model_engine
.
model
)

src/llamafactory/v1/core/model_loader.py  (new file, mode 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
Init Phase:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.
"""

import torch
from transformers import AutoConfig, AutoProcessor

from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor


logger = logging.get_logger(__name__)


class ModelLoader:
    """Model loader.

    Args:
        model_args: Model arguments.
        is_train: Whether to train the model.
    """

    def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
        self.args = model_args
        """Model arguments."""
        self.is_train = is_train
        """Whether to train the model."""
        self.processor = self._init_processor()
        """Tokenizer or multi-modal processor."""
        self.model_config = self._init_model_config()
        """Model configuration."""
        self.model = self._init_model()
        """HF model."""

    def _init_processor(self) -> Processor:
        """Init processor."""
        return AutoProcessor.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
            use_fast=self.args.use_fast_processor,
        )

    def _init_model_config(self) -> HFConfig:
        """Init model config."""
        return AutoConfig.from_pretrained(
            self.args.model,
            trust_remote_code=self.args.trust_remote_code,
        )

    def _init_model(self) -> HFModel:
        """Init model.

        Let transformers handle the model init context.
        https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
        """
        if self.args.model_class == ModelClass.LLM:
            from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

            if type(self.model_config) in AutoModelForImageTextToText._model_mapping.keys():
                AutoClass = AutoModelForImageTextToText
            else:
                AutoClass = AutoModelForCausalLM
        elif self.args.model_class == ModelClass.CLS:
            from transformers import AutoModelForTokenClassification

            AutoClass = AutoModelForTokenClassification
        else:
            from transformers import AutoModel

            AutoClass = AutoModel

        # map the entire model to the current accelerator
        model = AutoClass.from_pretrained(
            self.args.model,
            config=self.model_config,
            dtype="auto",
            device_map=DistributedInterface().current_accelerator,
            trust_remote_code=self.args.trust_remote_code,
        )

        if self.args.peft_config is None:
            if self.is_train:
                logger.info_rank0("Fine-tuning mode: full tuning")
                model = model.to(torch.float32)
            else:
                logger.info_rank0("Inference the original model")
        else:
            from ..plugins.model_plugins.peft import PeftPlugin

            model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)

        if self.args.kernel_config is not None:
            from ..plugins.model_plugins.kernels.interface import KernelPlugin

            model = KernelPlugin(self.args.kernel_config.name)(
                model=model, include_kernels=self.args.kernel_config.get("include_kernels")
            )

        return model


if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
    """
    from ..config.arg_parser import get_args

    _, model_args, *_ = get_args()
    model_loader = ModelLoader(model_args=model_args)
    print(model_loader.processor)
    print(model_loader.model_config)
    print(model_loader.model)

src/llamafactory/v1/core/trainer_utils/__init__.py  (new file, mode 0 → 100644; empty)

src/llamafactory/v1/core/trainer_utils/callback.py  (new file, mode 0 → 100644; empty)