Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
ca625f43
Commit
ca625f43
authored
Mar 30, 2026
by
shihm
Browse files
uodata
parent
7164651d
Changes
327
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1603 additions
and
0 deletions
+1603
-0
src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
...lamafactory/v1/plugins/model_plugins/kernels/interface.py
+132
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/__init__.py
...afactory/v1/plugins/model_plugins/kernels/ops/__init__.py
+0
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/__init__.py
...tory/v1/plugins/model_plugins/kernels/ops/mlp/__init__.py
+0
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py
...v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py
+342
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py
...ry/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py
+168
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/__init__.py
...v1/plugins/model_plugins/kernels/ops/rms_norm/__init__.py
+0
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
...lugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
+90
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/__init__.py
...ory/v1/plugins/model_plugins/kernels/ops/rope/__init__.py
+0
-0
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py
...ory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py
+146
-0
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
...llamafactory/v1/plugins/model_plugins/kernels/registry.py
+97
-0
src/llamafactory/v1/plugins/model_plugins/peft.py
src/llamafactory/v1/plugins/model_plugins/peft.py
+57
-0
src/llamafactory/v1/plugins/model_plugins/quantization.py
src/llamafactory/v1/plugins/model_plugins/quantization.py
+0
-0
src/llamafactory/v1/plugins/model_plugins/rendering.py
src/llamafactory/v1/plugins/model_plugins/rendering.py
+56
-0
src/llamafactory/v1/plugins/model_plugins/templates/__init__.py
...amafactory/v1/plugins/model_plugins/templates/__init__.py
+13
-0
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
+259
-0
src/llamafactory/v1/plugins/model_plugins/templates/qwen3_nothink.py
...ctory/v1/plugins/model_plugins/templates/qwen3_nothink.py
+209
-0
src/llamafactory/v1/plugins/sampler_plugins/__init__.py
src/llamafactory/v1/plugins/sampler_plugins/__init__.py
+0
-0
src/llamafactory/v1/plugins/sampler_plugins/vllm.py
src/llamafactory/v1/plugins/sampler_plugins/vllm.py
+0
-0
src/llamafactory/v1/plugins/trainer_plugins/__init__.py
src/llamafactory/v1/plugins/trainer_plugins/__init__.py
+0
-0
src/llamafactory/v1/plugins/trainer_plugins/batching.py
src/llamafactory/v1/plugins/trainer_plugins/batching.py
+34
-0
No files found.
src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel interface.
Init Phase:
1. Scan all kernels.
2. Register default kernels.
3. Define kernel plugin.
"""
import
importlib
from
pathlib
import
Path
from
....utils.logging
import
get_logger
from
....utils.plugin
import
BasePlugin
from
.registry
import
Registry
logger
=
get_logger
(
__name__
)
def scan_all_kernels():
    r"""Scan all kernels in the ``ops`` directory.

    Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
    Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.

    Returns:
        dict[str, type[BaseKernel]]: A dictionary of registered kernels (empty if ``ops`` is missing).

    .. note::
        This function assumes that the ``ops`` directory is located in the same directory as this file.
        It recursively searches for ``.py`` files and constructs the module path for import.
    """
    ops_path = Path(__file__).parent / "ops"
    if not ops_path.exists():
        # FIX: previously returned None here, which made `default_kernels.keys()`
        # in the callers crash with AttributeError; return an empty dict instead.
        return {}

    base_package = __package__
    for file_path in ops_path.rglob("*.py"):
        if file_path.name == "__init__.py":
            continue

        # calculate the relative path:
        #   file_path = .../kernels/ops/mlp/npu_swiglu.py
        #   rel_path  = ops/mlp/npu_swiglu.py
        rel_path = file_path.relative_to(Path(__file__).parent)
        # build module path: "ops/mlp/npu_swiglu.py" -> "ops.mlp.npu_swiglu"
        # ([:-3] strips the trailing ".py")
        module_name = ".".join(rel_path.parts)[:-3]
        full_module_name = f"{base_package}.{module_name}"
        try:
            importlib.import_module(full_module_name)
        except Exception as e:
            # Best-effort loading: a broken kernel module is skipped with a warning.
            logger.warning(f"[Kernel Registry] Failed to import {full_module_name} when loading kernels: {e}")

    return Registry.get_registered_kernels()


# Populated once at import time; maps kernel_id -> kernel class.
default_kernels = scan_all_kernels()
def get_default_kernels():
    r"""Get a list of default registered kernel IDs.

    Returns:
        list[str]: List of kernel IDs.
    """
    # Iterating a dict yields its keys, so list() produces the kernel IDs.
    return list(default_kernels)
def apply_kernel(kernel_id: str, **kwargs):
    r"""Applies a specific kernel to the model.

    Args:
        kernel_id (str): The ID of the kernel to apply.
        **kwargs: Keyword arguments passed to the kernel application function.
            Typically includes the model instance.

    Returns:
        HFModel: The model with applied kernel.

    Raises:
        ValueError: If no kernel is registered under ``kernel_id``.
    """
    if kernel_id not in default_kernels:
        raise ValueError(f"Kernel {kernel_id} not found")

    default_kernels[kernel_id].apply(**kwargs)
class KernelPlugin(BasePlugin):
    r"""Plugin for managing kernel optimizations."""

    # No extra behavior: registration and dispatch are inherited from BasePlugin;
    # implementations attach themselves via ``KernelPlugin(<name>).register``.
    pass
@KernelPlugin("auto").register
def apply_default_kernels(**kwargs):
    r"""Applies all default registered kernels to the model.

    Args:
        **kwargs: Keyword arguments passed to the kernel application function.
            Typically includes the model instance and the include_kernels configuration.

    Returns:
        HFModel: The model with applied kernels.

    Raises:
        ValueError: If a requested kernel ID is not registered.
    """
    include_kernels = kwargs.get("include_kernels")
    if not include_kernels:
        # None/False/empty string: nothing requested, return the model untouched.
        return kwargs.get("model")

    if include_kernels == "auto" or include_kernels is True:
        # True/auto: apply every registered kernel.
        use_kernels = default_kernels.keys()
    else:
        # Comma-separated list: "kernel_id1,kernel_id2,kernel_id3"
        use_kernels = include_kernels.split(",")

    for kernel in use_kernels:
        if kernel not in default_kernels:
            raise ValueError(f"Kernel {kernel} not found")

        apply_kernel(kernel, **kwargs)

    return kwargs.get("model")
src/llamafactory/v1/plugins/model_plugins/kernels/ops/__init__.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/__init__.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused MoE kernels.
Init Phase:
1. Define GMM functions.
2. Define NPU fused MoE functions.
3. Register NPU fused MoE kernel.
"""
import
types
import
torch
import
torch.nn.functional
as
F
try
:
import
torch_npu
except
ImportError
:
pass
from
......accelerator.helper
import
DeviceType
from
......utils.packages
import
is_transformers_version_greater_than
from
......utils.types
import
HFModel
from
...base
import
BaseKernel
from
...registry
import
register_kernel
class GmmFunction(torch.autograd.Function):
    r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""

    @staticmethod
    def forward(ctx, x, weight, group_list):
        r"""Performs the forward pass of Grouped Matrix Multiplication.

        Args:
            ctx: Context object to save tensors for backward pass.
            x (Tensor): Input tensor.
            weight (Tensor): Weight tensor.
            group_list (list): List of group sizes.

        Returns:
            Tensor: The result of the grouped matrix multiplication.
        """
        # Keep inputs for the backward pass; group_list is not a tensor, so it
        # is stashed directly on ctx rather than via save_for_backward.
        ctx.save_for_backward(x, weight)
        ctx.group_list = group_list
        # split_item=2 / group_type=0 / group_list_type=1: token-grouped matmul;
        # NOTE(review): exact flag semantics come from torch_npu — confirm against
        # the npu_grouped_matmul API documentation before changing.
        fwd_output = torch_npu.npu_grouped_matmul(
            [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        return fwd_output

    @staticmethod
    def backward(ctx, grad_output):
        r"""Performs the backward pass of Grouped Matrix Multiplication.

        Args:
            ctx: Context object containing saved tensors.
            grad_output (Tensor): Gradient with respect to the output.

        Returns:
            tuple: Gradients with respect to input, weight, and None for group_list.
        """
        input_tensor, weight = ctx.saved_tensors
        group_list = ctx.group_list
        # grad wrt input uses W^T: transpose the last two dims of the stacked weights.
        weight = torch.transpose(weight, 1, 2)
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        # grad wrt weight = x^T @ dy, grouped per expert (split_item=3 / group_type=2).
        grad_weight = torch_npu.npu_grouped_matmul(
            [input_tensor.T],
            [grad_output],
            bias=None,
            group_list=group_list,
            split_item=3,
            group_type=2,
            group_list_type=1,
        )[0]
        # group_list is a non-differentiable argument -> None.
        return grad_input, grad_weight, None
class HybridGmmFunction(torch.autograd.Function):
    r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""

    @staticmethod
    def forward(ctx, num_experts, *args):
        r"""Performs the forward pass of Hybrid GMM.

        Args:
            ctx: Context object to save tensors.
            num_experts (int): Number of experts.
            *args: Variable length argument list containing inputs and weights
                (first ``num_experts`` entries are inputs, the rest are weights).

        Returns:
            tuple: The outputs of the grouped matrix multiplication.
        """
        # args layout: [x_0..x_{E-1}, w_0..w_{E-1}]
        x_list = list(args[:num_experts])
        weight_list = list(args[num_experts:])
        # Per-expert token counts, needed to rebuild the group_list in backward.
        split_sizes = [x.shape[0] for x in x_list]
        ctx.split_sizes = split_sizes
        ctx.num_experts = num_experts
        ctx.save_for_backward(*args)
        # split_item=0 / group_type=-1: independent per-expert matmuls returning a list.
        outputs = torch_npu.npu_grouped_matmul(
            x_list, weight_list, bias=None, group_list=None, split_item=0, group_type=-1
        )
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        r"""Performs the backward pass of Hybrid GMM.

        Args:
            ctx: Context object containing saved tensors.
            *grad_outputs: Gradients with respect to the outputs.

        Returns:
            tuple: Gradients with respect to inputs and weights (None for num_experts).
        """
        saved_tensors = ctx.saved_tensors
        num_experts = ctx.num_experts
        split_sizes = ctx.split_sizes
        x_list = list(saved_tensors[:num_experts])
        weight_list = list(saved_tensors[num_experts:])
        grad_outputs_contiguous = [g.contiguous() for g in grad_outputs]
        w_t_list = [w.t() for w in weight_list]
        # grad wrt each input: dy_i @ w_i^T, again as independent matmuls.
        grad_x_list = torch_npu.npu_grouped_matmul(
            grad_outputs_contiguous,  # List[Tensor], each [M_i, N]
            w_t_list,  # List[Tensor], each [N, K] (view)
            bias=None,
            group_list=None,
            split_item=0,
            group_type=-1,
        )
        # grad wrt weights is computed in one grouped call over the concatenation.
        x_concat = torch.cat(x_list, dim=0)
        dy_concat = torch.cat(grad_outputs_contiguous, dim=0)  # [Total_M, N]
        group_list = torch.tensor(split_sizes, device=x_concat.device, dtype=torch.int64)
        grad_w_stack = torch_npu.npu_grouped_matmul(
            [x_concat.t()],
            [dy_concat],
            bias=None,
            group_list=group_list,
            split_item=3,
            group_type=2,
            group_list_type=1,
        )[0]
        if grad_w_stack.dim() == 3:
            # [E, K, N] stack -> one gradient tensor per expert.
            grad_w_list = list(torch.unbind(grad_w_stack, dim=0))
        else:
            raise RuntimeError(f"Unexpected grad_w_stack shape: {grad_w_stack.shape}")
        # First slot is the (non-differentiable) num_experts argument.
        return (None, *grad_x_list, *grad_w_list)
class NpuMoeFused:
    r"""Container for NPU fused MoE forward functions."""

    @staticmethod
    def npu_moe_experts_forward(
        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
    ) -> torch.Tensor:
        r"""Forward pass for MoE experts using NPU fused operations.

        Bound onto the experts module via ``types.MethodType``, so ``self`` is the
        patched module instance (expected to expose ``hidden_size``, ``num_experts``,
        ``gate_up_proj`` and ``down_proj`` — TODO confirm against the HF definition).

        Args:
            self: The MoE layer instance.
            hidden_states (Tensor): Input hidden states.
            routing_weights (Tensor): Routing weights.
            router_indices (Tensor): Router indices.

        Returns:
            Tensor: Output tensor after expert computation.
        """
        batch_size = hidden_states.shape[0]
        # Flatten (batch, seq, hidden) -> (tokens, hidden) for token-level routing.
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        # Sort tokens by expert; row_ids_map is needed to undo the permutation later.
        permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
            hidden_states, router_indices.to(torch.int32)
        )
        # Per-expert token counts for the grouped matmuls.
        tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
        intermediate_hidden_states = GmmFunction.apply(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
        # Fused SwiGLU over the gate/up halves of the intermediate projection.
        intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
        output = GmmFunction.apply(intermediate_activations, self.down_proj, tokens_per_expert)
        # Restore original token order and apply the routing probabilities.
        next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
        next_states = next_states.view(batch_size, -1, self.hidden_size)
        return next_states

    @staticmethod
    def npu_moe_sparse_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""Forward pass for sparse MoE block using NPU optimization.

        Args:
            self: The MoE sparse block instance (expects ``gate``, ``experts``,
                ``hidden_size`` and ``top_k`` attributes).
            hidden_states (Tensor): Input hidden states.

        Returns:
            Tensor: The routed output.
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        router_logits = self.gate(hidden_states)
        # Softmax in float32 for numerical stability before top-k selection.
        routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the selected top-k weights so they sum to 1 per token.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
        routed_out = self.experts(hidden_states, routing_weights, router_indices)
        return routed_out
class Qwen3NpuMoeFused:
    r"""Container for Qwen3 NPU fused MoE forward functions."""

    @staticmethod
    def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
        r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.

        Bound onto the block via ``types.MethodType``; ``self`` is expected to expose
        ``gate``, ``experts`` (each with gate/up/down projections), ``top_k``,
        ``num_experts`` and ``norm_topk_prob`` — TODO confirm against transformers.

        Args:
            self: The Qwen3 MoE block instance.
            hidden_states (Tensor): Input hidden states.

        Returns:
            tuple: A tuple containing the next states and router logits.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)
        # float32 softmax for stability; dim=1 matches the upstream HF implementation.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            # Renormalize the top-k probabilities in place.
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        # Sort tokens by expert; row_ids_map undoes the permutation afterwards.
        permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, selected_experts.int())
        # histc requires a float input; cast back to long for torch.split sizes.
        tokens_per_expert = torch.histc(
            selected_experts.float(), bins=self.num_experts, min=0, max=self.num_experts
        ).long()
        split_sizes = tokens_per_expert.tolist()
        input_list = list(torch.split(permuted_hidden_states, split_sizes, dim=0))
        # Transposed per-expert weights ([in, out] layout for the grouped matmul).
        gate_weights = [e.gate_proj.weight.t() for e in self.experts]
        up_weights = [e.up_proj.weight.t() for e in self.experts]
        down_weights = [e.down_proj.weight.t() for e in self.experts]
        gate_out_tuple = HybridGmmFunction.apply(len(input_list), *input_list, *gate_weights)
        up_out_tuple = HybridGmmFunction.apply(len(input_list), *input_list, *up_weights)
        # SwiGLU: silu(gate) * up, per expert group.
        inter_list = [F.silu(g) * u for g, u in zip(gate_out_tuple, up_out_tuple)]
        down_out_tuple = HybridGmmFunction.apply(len(inter_list), *inter_list, *down_weights)
        grouped_output = torch.cat(down_out_tuple, dim=0)
        # Restore token order and weight each token's expert outputs.
        next_states = torch_npu.npu_moe_token_unpermute(grouped_output, row_ids_map, probs=routing_weights)
        next_states = next_states.view(batch_size, sequence_length, -1)
        return next_states, router_logits
# moe patch config mapping
# Maps an HF architecture name -> {module class name -> replacement forward function}.
kernel_moe_mapping = {
    "Qwen3VLMoeForConditionalGeneration": {
        "Qwen3VLMoeTextExperts": NpuMoeFused.npu_moe_experts_forward,
        "Qwen3VLMoeTextSparseMoeBlock": NpuMoeFused.npu_moe_sparse_block_forward,
    }
}
# Only patch Qwen3MoeSparseMoeBlock on transformers < 5.0.0; presumably the MoE
# block layout this patch targets changed in 5.x — TODO confirm.
if not is_transformers_version_greater_than("5.0.0"):
    kernel_moe_mapping["Qwen3MoeForCausalLM"] = {
        "Qwen3MoeSparseMoeBlock": Qwen3NpuMoeFused.qwen3moe_sparse_moe_block_forward
    }
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
    r"""NPU Fused MoE Kernel implementation."""

    _kernel_id = "npu_fused_moe"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, **kwargs) -> HFModel:
        r"""Applies the NPU fused MoE kernel to the model.

        Args:
            **kwargs: Keyword arguments containing the model.

        Returns:
            HFModel: The model with patched MoE forward functions.

        Raises:
            ValueError: If the model is not provided.
            RuntimeError: If dependencies are not met.
        """
        model = kwargs.get("model", None)
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            # FIX: the old message named a nonexistent class ("NpuMoEFusedMoEKernel");
            # use cls.__name__ for consistency with the other NPU kernels.
            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")

        archs = getattr(model.config, "architectures", [])
        target_moe_mapping = None
        for arch in archs:
            if arch in kernel_moe_mapping:
                target_moe_mapping = kernel_moe_mapping[arch]
                break

        if target_moe_mapping is None:
            # No supported MoE architecture: leave the model untouched.
            return model

        for module in model.modules():
            class_name = module.__class__.__name__
            if class_name in target_moe_mapping:
                new_forward_func = target_moe_mapping[class_name]
                # Bind as an instance method so `self` resolves to the module.
                module.forward = types.MethodType(new_forward_func, module)

        return model
src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused SwiGLU kernels.
Init Phase:
1. Define SwiGLU forward functions.
2. Register NPU fused SwiGLU kernel.
"""
import
re
import
types
import
torch
from
......accelerator.helper
import
DeviceType
from
......utils.types
import
HFModel
from
...base
import
BaseKernel
from
...registry
import
register_kernel
try
:
import
torch_npu
except
ImportError
:
pass
def npu_swiglu_forward(self, hidden_state):
    r"""SwiGLU forward pass for NPU.

    Args:
        self: The MLP layer instance.
        hidden_state (Tensor): Input hidden state.

    Returns:
        Tensor: Output of SwiGLU.
    """
    # Concatenate gate and up projections so the fused NPU op can split them.
    gate_up = torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1)
    activated = torch_npu.npu_swiglu(gate_up, dim=-1)
    return self.down_proj(activated)
def _npu_swiglu_glm4_forward(self, hidden_states):
    r"""SwiGLU forward pass for GLM4 on NPU.

    Args:
        self: The GLM4 MLP layer instance.
        hidden_states (Tensor): Input hidden states.

    Returns:
        Tensor: Output of SwiGLU.
    """
    # GLM4 uses a single fused gate+up projection; split into the two halves.
    fused = self.gate_up_proj(hidden_states)
    gate_half, up_half = fused.chunk(2, dim=-1)
    # Re-concatenate in the layout expected by the fused NPU SwiGLU op.
    activated = torch_npu.npu_swiglu(torch.cat((gate_half, up_half), dim=-1), dim=-1)
    return self.down_proj(activated)
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
    r"""SwiGLU forward pass for Gemma3nText on NPU.

    Args:
        self: The Gemma3nText MLP layer instance.
        hidden_states (Tensor): Input hidden states.

    Returns:
        Tensor: Output of SwiGLU.
    """
    gate_out = self.gate_proj(hidden_states)
    if self.activation_sparsity > 0.0:
        # Gemma3n applies activation sparsification to the gate branch.
        gate_out = self._gaussian_topk(gate_out)
    up_out = self.up_proj(hidden_states)
    activated = torch_npu.npu_swiglu(torch.cat((gate_out, up_out), dim=-1), dim=-1)
    return self.down_proj(activated)
@register_kernel
class NpuSwiGluKernel(BaseKernel):
    r"""NPU Kernel for fused SwiGLU activation."""

    # just support apply to the following module layers
    expect_modules = frozenset(
        {
            "Qwen3VLMoeTextMLP",
            "Qwen3VLTextMLP",
            "Qwen3OmniMoeThinkerTextMLP",
            "Qwen3OmniMoeMLP",
            "Qwen3OmniMoeTalkerTextMLP",
            "Qwen3OmniMoeCode2WavMlp",
            "Qwen3NextMLP",
            "Qwen3MoeMLP",
            "Qwen3MLP",
            "Qwen2MLP",
            "Qwen2MoeMLP",
            "Qwen2_5_VLMLP",
            "Qwen2_5OmniMLP",
            "Llama4TextMLP",
            "LlamaMLP",
            "Glm4MLP",
            "Glm4MoeMLP",
            "Glm4vMoeTextMLP",
            "Gemma3MLP",
            "Gemma2MLP",
            "Gemma3nTextMLP",
            "Phi3MLP",
            "DeepseekV2MLP",
            "DeepseekV3MLP",
            "SeedOssMLP",
        }
    )

    _kernel_id = "npu_fused_swiglu"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        r"""Applies the NPU fused SwiGLU kernel to the model.

        Args:
            **kwargs: Keyword arguments containing the model.

        Returns:
            HFModel: The model with patched SwiGLU forward functions.

        Raises:
            ValueError: If the model is not provided.
            RuntimeError: If dependencies are not met.
        """
        model = kwargs.get("model", None)
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            raise RuntimeError("torch_npu is not available but NpuSwiGluKernel was called.")

        # Mapping of specific mlp modules to their corresponding kernel implementations
        kernel_mapping = {
            "Glm4MLP": _npu_swiglu_glm4_forward,
            "Glm4vTextMLP": _npu_swiglu_glm4_forward,
            "Phi3MLP": _npu_swiglu_glm4_forward,
            "Gemma3nTextMLP": _npu_swiglu_gemma3ntext_forward,
        }

        # FIX: the previous implementation additionally regex-matched "MLP" in the
        # class name, which was redundant (every allow-listed name already contains
        # "MLP"/"Mlp") and iterated named_modules() without using the names.
        for module in model.modules():
            class_name = module.__class__.__name__
            if class_name in cls.expect_modules:
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                kernel_func = kernel_mapping.get(class_name, npu_swiglu_forward)
                module.forward = types.MethodType(kernel_func, module)

        return model
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/__init__.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RMSNorm kernels.
Init Phase:
1. Define RMSNorm forward function.
2. Register NPU fused RMSNorm kernel.
"""
import
re
import
types
from
......accelerator.helper
import
DeviceType
from
......utils.types
import
HFModel
from
...base
import
BaseKernel
from
...registry
import
register_kernel
def npu_rms_norm_forward(self, hidden_states):
    r"""NPU forward implementation for RMSNorm.

    Args:
        self: RMSNorm module instance with `weight` and `variance_epsilon`.
        hidden_states (Tensor): Input hidden states tensor, same shape as the baseline.

    Returns:
        Tensor: Normalized tensor consistent with the baseline RMSNorm behavior.
    """
    # Imported lazily so that merely defining this function does not require torch_npu.
    import torch_npu

    # npu_rms_norm returns a tuple; the normalized tensor is the first element.
    result = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)
    return result[0]
@register_kernel
class NpuRMSNormKernel(BaseKernel):
    r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""

    _kernel_id = "npu_fused_rmsnorm"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.

        Key points:
            - Match modules whose class name contains "RMSNorm" (case-insensitive).
            - Bind `npu_rms_norm_forward` as an instance method via `types.MethodType` to
              replace the original `forward`.
            - Do not modify weights, hyperparameters, or module structure to ensure
              numerical behavior and interface consistency.

        Args:
            **kwargs: Keyword arguments containing the model.

        Returns:
            HFModel: The model with NPU fused RMSNorm.

        Raises:
            RuntimeError: If torch_npu is not available.
            ValueError: If the model is not provided.
        """
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")

        rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
        # FIX: iterate modules() directly — the previous named_modules() loop bound
        # a `name` variable that was never used.
        for module in model.modules():
            # Match any module whose class name contains "RMSNorm"
            if rms_norm_pattern.search(module.__class__.__name__):
                # Bind function as an instance method to preserve `self` semantics
                # and replace the original forward
                module.forward = types.MethodType(npu_rms_norm_forward, module)

        return model
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/__init__.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of NPU fused RoPE kernels.
Init Phase:
1. Define RoPE forward functions.
2. Register NPU fused RoPE kernel.
"""
import
sys
import
torch
from
......accelerator.helper
import
DeviceType
from
......utils.logging
import
get_logger
from
......utils.types
import
HFModel
from
...base
import
BaseKernel
from
...registry
import
register_kernel
logger
=
get_logger
(
__name__
)
try
:
import
torch_npu
except
ImportError
:
pass
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        position_ids (Tensor, optional): Unused; kept for signature compatibility
            with the transformers implementation being patched. Default: ``None``.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.

    Returns:
        tuple: (q_embed, k_embed) The embedded query and key tensors.
    """
    # Broadcast cos/sin across the head dimension, then apply the fused rotary op.
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    return torch_npu.npu_rotary_mul(q, cos_b, sin_b), torch_npu.npu_rotary_mul(k, cos_b, sin_b)
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.

    Args:
        q (Tensor): Query tensor.
        k (Tensor): Key tensor.
        cos (Tensor): Cosine part of embedding.
        sin (Tensor): Sine part of embedding.
        mrope_section (list): Multimodal RoPE section sizes.
        unsqueeze_dim (int): Dimension to unsqueeze cos and sin. Default: 1.

    Returns:
        tuple: (q_embed, k_embed) The embedded query and key tensors.
    """
    # Each section covers both rotary halves, hence the doubled section list.
    sections = mrope_section * 2

    def select_sections(t):
        # Split along the last dim and cycle through the 3 (temporal/height/width)
        # position streams, picking stream i % 3 for the i-th chunk.
        chunks = t.split(sections, dim=-1)
        return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1).unsqueeze(unsqueeze_dim)

    cos_sel = select_sections(cos)
    sin_sel = select_sections(sin)
    return torch_npu.npu_rotary_mul(q, cos_sel, sin_sel), torch_npu.npu_rotary_mul(k, cos_sel, sin_sel)
@register_kernel
class NpuRoPEKernel(BaseKernel):
    r"""NPU Kernel for Rotary Position Embedding."""

    _kernel_id = "npu_fused_rope"
    _device = DeviceType.NPU

    @classmethod
    def apply(cls, **kwargs) -> "HFModel":
        r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

        This function iterates through the model's modules to find attention layers,
        identifies the module where they are defined, and replaces the original
        `apply_rotary_pos_emb` function in that module's namespace with the
        NPU-accelerated version from this file.

        Args:
            **kwargs: Keyword arguments containing the model.

        Returns:
            HFModel: The model with patched RoPE functions.

        Raises:
            RuntimeError: If dependencies are not met.
            ValueError: If the model is not provided.
        """
        if not cls.check_deps():
            raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")

        model = kwargs.get("model", None)
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        # Names of defining modules that were already patched, so each transformers
        # module file is rewritten at most once.
        _modules = set()
        for module in model.modules():
            # Heuristic: attention layer classes are named "...Attention".
            if "Attention" in module.__class__.__name__:
                module_name = module.__class__.__module__
                if module_name in _modules:
                    continue

                try:
                    # Patch the function in the *defining* module's namespace, since
                    # the attention forward looks it up as a module-level global.
                    target_module = sys.modules[module_name]
                    if hasattr(target_module, "apply_rotary_pos_emb"):
                        # Identity check avoids re-patching with the same function.
                        if getattr(target_module, "apply_rotary_pos_emb") is not _apply_rotary_pos_emb:
                            setattr(target_module, "apply_rotary_pos_emb", _apply_rotary_pos_emb)
                            _modules.add(module_name)

                    if hasattr(target_module, "apply_multimodal_rotary_pos_emb"):
                        if (
                            getattr(target_module, "apply_multimodal_rotary_pos_emb")
                            is not _apply_multimodal_rotary_pos_emb_qwen25_vl
                        ):
                            setattr(
                                target_module,
                                "apply_multimodal_rotary_pos_emb",
                                _apply_multimodal_rotary_pos_emb_qwen25_vl,
                            )
                            _modules.add(module_name)
                except Exception as e:
                    # Best-effort patching: a failure for one module should not
                    # abort kernel application for the rest of the model.
                    logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")

        return model
src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of kernel registry.
Init Phase:
1. Define kernel registry.
2. Register kernels.
"""
from
typing
import
Optional
from
....accelerator.helper
import
get_current_accelerator
from
.base
import
BaseKernel
__all__
=
[
"Registry"
,
"register_kernel"
]
class Registry:
    r"""Registry for managing kernel implementations.

    Storage structure: ``{ "kernel_id": Class }``
    """

    # Global mapping from kernel ID to the registered kernel class.
    _kernels: dict[str, type[BaseKernel]] = {}

    @classmethod
    def register(cls, kernel_cls: type[BaseKernel]):
        r"""Decorator to register a kernel class.

        The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.

        Args:
            kernel_cls (type[BaseKernel]): The kernel class to register.

        Returns:
            type[BaseKernel]: The registered kernel class.

        Raises:
            TypeError: If the class does not inherit from :class:`BaseKernel`.
            ValueError: If the kernel ID is missing or already registered.
        """
        if not issubclass(kernel_cls, BaseKernel):
            raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")

        kernel_id = kernel_cls.get_kernel_id()
        device = kernel_cls.get_device()
        # The device type of the current accelerator does not match the device type required
        # by the kernel, skip registration. Return the class anyway: ``register`` is used as
        # a decorator, and returning ``None`` here would replace the decorated class with
        # ``None`` in the defining module.
        if device != get_current_accelerator().type:
            return kernel_cls

        if not kernel_id:
            raise ValueError(f"Kernel ID (_kernel_id) is needed for {kernel_cls} to register")

        if kernel_id in cls._kernels:
            raise ValueError(f"{kernel_id} already registered! The registered kernel is {cls._kernels[kernel_id]}")

        cls._kernels[kernel_id] = kernel_cls
        return kernel_cls

    @classmethod
    def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
        r"""Retrieves a registered kernel implementation by its ID.

        Args:
            kernel_id (str): The ID of the kernel to retrieve.

        Returns:
            Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
        """
        return cls._kernels.get(kernel_id)

    @classmethod
    def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
        r"""Returns a dictionary of all registered kernels.

        Returns:
            dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
        """
        return cls._kernels
# Export a module-level decorator alias so kernel modules can write
# ``@register_kernel`` instead of ``@Registry.register``.
register_kernel = Registry.register
src/llamafactory/v1/plugins/model_plugins/peft.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Literal
,
TypedDict
from
peft
import
LoraConfig
,
PeftModel
,
get_peft_model
from
...utils.plugin
import
BasePlugin
from
...utils.types
import
HFModel
class LoraConfigDict(TypedDict, total=False):
    """Configuration schema accepted by the ``lora`` PEFT plugin (all keys optional)."""

    # Plugin name discriminator.
    name: Literal["lora"]
    # LoRA rank.
    r: int
    # LoRA alpha scaling factor.
    lora_alpha: int
    # Names of the modules LoRA adapters are attached to.
    target_modules: list[str]
class
FreezeConfigDict
(
TypedDict
,
total
=
False
):
name
:
Literal
[
"freeze"
]
"""Plugin name."""
freeze_trainable_layers
:
int
"""Freeze trainable layers."""
freeze_trainable_modules
:
list
[
str
]
|
None
"""Freeze trainable modules."""
class PeftPlugin(BasePlugin):
    """Plugin entry point dispatching to registered PEFT implementations (``lora``, ``freeze``, ...)."""

    def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
        # NOTE(review): ``is_train`` is accepted but not forwarded — the super call only
        # passes (model, config), while registered implementations declare an ``is_train``
        # parameter. Confirm whether BasePlugin injects it or it should be passed here.
        return super().__call__(model, config)
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
    """Wrap ``model`` with LoRA adapters built from ``config`` and return the PEFT model."""
    lora_config = LoraConfig(**config)
    return get_peft_model(model, lora_config)
@PeftPlugin("freeze").register
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
    """Freeze-tuning (partial-parameter training) entry point.

    Args:
        model (HFModel): The model to configure.
        config (FreezeConfigDict): Freeze-tuning options.
        is_train (bool): Whether the model is being prepared for training.

    Raises:
        NotImplementedError: Always — freeze tuning is not implemented yet.
    """
    # Descriptive message so the failure is self-explanatory at the call site.
    raise NotImplementedError("Freeze (partial-parameter) tuning is not implemented yet.")
src/llamafactory/v1/plugins/model_plugins/quantization.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/model_plugins/rendering.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib
from
...utils
import
logging
from
...utils.plugin
import
BasePlugin
from
...utils.types
import
Message
,
ModelInput
,
Processor
logger
=
logging
.
get_logger
(
__name__
)
class RenderingPlugin(BasePlugin):
    """Dispatches chat-template rendering/parsing to functions registered per template name."""

    # Template module names already attempted; class-level on purpose so each
    # template module is imported at most once per process.
    _attempted_template_imports: set[str] = set()

    def _ensure_template_imported(self) -> None:
        # Lazily import ``<this package>.templates.<name>`` so the template module's
        # ``register`` decorators run before the first method lookup.
        if self.name is None or self.name in self._attempted_template_imports:
            return

        full_module_name = f"{__package__}.templates.{self.name}"
        # Record the attempt before importing so a failing import is not retried
        # on every subsequent lookup.
        self._attempted_template_imports.add(self.name)
        try:
            importlib.import_module(full_module_name)
        except Exception as exc:
            # Best-effort: a missing/broken template module surfaces later as a
            # lookup failure; here we only log.
            logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")

    def __getitem__(self, method_name: str):
        # Make sure the template's registrations are loaded before delegating
        # the lookup to BasePlugin.
        self._ensure_template_imported()
        return super().__getitem__(method_name)

    def render_messages(
        self,
        processor: Processor,
        messages: list[Message],
        tools: str | None = None,
        is_generate: bool = False,
        enable_thinking: bool = False,
    ) -> ModelInput:
        """Render messages in the template format."""
        return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)

    def parse_messages(self, generated_text: str) -> Message:
        """Parse messages in the template format."""
        return self["parse_messages"](generated_text)
src/llamafactory/v1/plugins/model_plugins/templates/__init__.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
re
from
....utils.constants
import
IGNORE_INDEX
from
....utils.helper
import
get_tokenizer
from
....utils.types
import
Message
,
ModelInput
,
Processor
,
ToolCall
from
..rendering
import
RenderingPlugin
def _update_model_input(
    processor: Processor,
    input_ids: list[int],
    labels: list[int],
    loss_weights: list[float],
    temp_str: str,
    temp_weight: float,
) -> str:
    """Update model input with temporary string.

    Tokenizes ``temp_str`` (without special tokens) and appends the resulting ids,
    labels and per-token loss weights to the three lists **in place**.

    Args:
        processor: Processor holding the tokenizer.
        input_ids: Running list of token ids (mutated).
        labels: Running list of label ids (mutated).
        loss_weights: Running list of per-token loss weights (mutated).
        temp_str: Accumulated template text to flush.
        temp_weight: Loss weight applied to every token of ``temp_str``.

    Returns:
        str: Always ``""`` so callers can reset their accumulator in one assignment.
    """
    if not temp_str:
        return ""

    tokenizer = get_tokenizer(processor)
    temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
    input_ids.extend(temp_ids)
    loss_weights.extend([temp_weight] * len(temp_ids))
    # (Near-)zero-weight tokens are masked out of the loss via IGNORE_INDEX.
    if temp_weight > 1e-6:
        labels.extend(temp_ids)
    else:
        labels.extend([IGNORE_INDEX] * len(temp_ids))

    return ""
def _concat_text_content(message: Message) -> str:
    """Concatenate the text parts of a message into a single string.

    Raises:
        ValueError: If the message contains a non-text content part.
    """
    parts: list[str] = []
    for item in message["content"]:
        if item["type"] != "text":
            raise ValueError(f"Unsupported content type: {item['type']}")

        parts.append(item["value"])

    return "".join(parts)
def _get_last_query_index(messages: list[Message]) -> int:
    """Find the last user query index, excluding wrapped tool responses.

    Scans from the end for a user message whose content is purely text and is not
    a ``<tool_response>...</tool_response>`` wrapper; falls back to the last index.
    """
    for idx in reversed(range(len(messages))):
        candidate = messages[idx]
        if candidate["role"] != "user":
            continue

        parts = candidate["content"]
        # Only purely-textual user messages qualify as queries.
        if any(part["type"] != "text" for part in parts):
            continue

        joined = "".join(part["value"] for part in parts)
        if not (joined.startswith("<tool_response>") and joined.endswith("</tool_response>")):
            return idx

    return len(messages) - 1
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
    """Split assistant message into text, reasoning and tool calls.

    Raises:
        ValueError: On malformed tool-call JSON or an unsupported content type.
    """
    text_parts: list[str] = []
    reasoning_parts: list[str] = []
    calls: list[ToolCall] = []
    for part in message["content"]:
        kind = part["type"]
        if kind == "text":
            text_parts.append(part["value"])
        elif kind == "reasoning":
            reasoning_parts.append(part["value"])
        elif kind == "tool_call":
            # Parse eagerly so malformed tool calls fail here with context.
            try:
                parsed: ToolCall = json.loads(part["value"])
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {part['value']}.")

            calls.append(parsed)
        else:
            raise ValueError(f"Unsupported content type: {kind}")

    return "".join(text_parts), "".join(reasoning_parts), calls
@RenderingPlugin("qwen3").register("render_messages")
def render_qwen3_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
    enable_thinking: bool = False,
) -> ModelInput:
    """Render messages in the Qwen3 template format.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
    """
    input_ids, labels, loss_weights = [], [], []
    # temp_str accumulates template text until flushed via _update_model_input,
    # which always returns "" so the accumulator resets on assignment.
    temp_str, temp_weight = "", 0.0
    if tools:
        # Tool definitions are injected into the (possibly synthesized) system turn.
        temp_str += "<|im_start|>system\n"
        if messages[0]["role"] == "system":
            temp_str += _concat_text_content(messages[0]) + "\n\n"
            temp_weight = messages[0].get("loss_weight", 0.0)

        temp_str += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )
        try:
            tools = json.loads(tools)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid tools format: {str(tools)}.")

        if not isinstance(tools, list):
            tools = [tools]

        for tool in tools:
            temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

        temp_str += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
    elif messages[0]["role"] == "system":
        temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
        temp_weight = messages[0].get("loss_weight", 0.0)

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
    last_query_index = _get_last_query_index(messages)
    for turn_idx, message in enumerate(messages):
        if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
            temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 0.0)
        elif message["role"] == "assistant":
            temp_str += "<|im_start|>" + message["role"] + "\n"
            text_content, reasoning_content, tool_calls = _split_assistant_content(message)
            # Only turns after the last user query keep an explicit <think> block
            # (final turn always; earlier post-query turns only if they reasoned),
            # mirroring the upstream Qwen3 chat template.
            if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
                temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
            else:
                temp_str += text_content

            for tool_call_idx, tool_call in enumerate(tool_calls):
                # Newline before every tool call except a first one with no preceding text.
                if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
                    temp_str += "\n"

                arguments = tool_call.get("arguments")
                # Arguments may already be a JSON string; only serialize objects.
                if isinstance(arguments, str):
                    arguments_str = arguments
                else:
                    arguments_str = json.dumps(arguments, ensure_ascii=False)

                temp_str += (
                    '<tool_call>\n{"name": "' + tool_call["name"] + '", "arguments": ' + arguments_str + "}\n</tool_call>"
                )

            temp_str += "<|im_end|>\n"
            # Assistant turns default to full loss weight.
            temp_weight = message.get("loss_weight", 1.0)
        elif message["role"] == "tool":
            # Consecutive tool results are wrapped inside a single user turn.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                temp_str += "<|im_start|>user"

            temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
            if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
                temp_str += "<|im_end|>\n"

            temp_weight = message.get("loss_weight", 0.0)

        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    if is_generate:
        # Open the assistant turn for generation; never trained on.
        temp_str += "<|im_start|>assistant\n"
        temp_weight = 0.0
        if enable_thinking is False:
            # Pre-filled empty think block disables thinking mode.
            temp_str += "<think>\n\n</think>\n\n"

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
# Register under "parse_messages" to match RenderingPlugin.parse_messages, which
# looks up self["parse_messages"] (the render function likewise registers the
# plural "render_messages"); the previous singular key was never looked up.
@RenderingPlugin("qwen3").register("parse_messages")
def parse_qwen3_message(generated_text: str) -> Message:
    """Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.

    Args:
        generated_text (str): The generated text in the Qwen3 template format.

    Returns:
        Message: The parsed message.

    Raises:
        ValueError: If a ``<tool_call>`` block does not contain valid JSON.
    """
    # Match <think>...</think> and <tool_call>...</tool_call> blocks (backreference
    # \1 pairs the closing tag), swallowing surrounding whitespace.
    pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    content = []
    last_end = 0
    for match in pattern.finditer(generated_text):
        start, end = match.span()
        if start > last_end:
            # Plain text between tagged blocks.
            text = generated_text[last_end:start].strip()
            if text:
                content.append({"type": "text", "value": text})

        tag_type = match.group(1)
        tag_value = match.group(2).strip()
        if tag_type == "think":
            content.append({"type": "reasoning", "value": tag_value.strip()})
        elif tag_type == "tool_call":
            # Validate JSON eagerly so malformed tool calls fail here, not downstream.
            try:
                json.loads(tag_value.strip())
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")

            content.append({"type": "tool_call", "value": tag_value.strip()})

        last_end = end

    if last_end < len(generated_text):
        # Trailing plain text after the last tagged block.
        text = generated_text[last_end:].strip()
        if text:
            content.append({"type": "text", "value": text})

    return Message(role="assistant", content=content)
src/llamafactory/v1/plugins/model_plugins/templates/qwen3_nothink.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
re
from
....utils.constants
import
IGNORE_INDEX
from
....utils.helper
import
get_tokenizer
from
....utils.types
import
Message
,
ModelInput
,
Processor
,
ToolCall
from
..rendering
import
RenderingPlugin
def _update_model_input(
    processor: Processor,
    input_ids: list[int],
    labels: list[int],
    loss_weights: list[int],
    temp_str: str,
    temp_weight: float,
) -> str:
    """Tokenize the accumulated template text and append it to the model input in place.

    Returns ``""`` unconditionally so callers can reset their accumulator string.
    """
    if not temp_str:
        return ""

    token_ids = get_tokenizer(processor).encode(temp_str, add_special_tokens=False)
    input_ids += token_ids
    loss_weights += [temp_weight] * len(token_ids)
    # Mask (near-)zero-weight tokens out of the loss with IGNORE_INDEX.
    labels += token_ids if temp_weight > 1e-6 else [IGNORE_INDEX] * len(token_ids)
    return ""
def _concat_text_content(message: Message) -> str:
    """Concatenate text fields in a message.

    Args:
        message: Message whose content parts are joined.

    Returns:
        str: The concatenated text.

    Raises:
        ValueError: If any content part is not of type ``"text"``.
    """
    message_text = ""
    for content in message["content"]:
        if content["type"] == "text":
            message_text += content["value"]
        else:
            raise ValueError(f"Unsupported content type: {content['type']}")

    return message_text
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
    enable_thinking: bool = False,
) -> ModelInput:
    """Render messages in the Qwen3 nothink template format.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
    """
    input_ids, labels, loss_weights = [], [], []
    # temp_str accumulates template text until flushed via _update_model_input,
    # which always returns "" so the accumulator resets on assignment.
    temp_str, temp_weight = "", 0.0
    if tools:
        # Tool definitions are injected into the (possibly synthesized) system turn.
        temp_str += "<|im_start|>system\n"
        if messages[0]["role"] == "system":
            temp_str += _concat_text_content(messages[0]) + "\n\n"
            temp_weight = messages[0].get("loss_weight", 0.0)

        temp_str += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )
        try:
            tools = json.loads(tools)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid tools format: {str(tools)}.")

        if not isinstance(tools, list):
            tools = [tools]

        for tool in tools:
            temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

        temp_str += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
    elif messages[0]["role"] == "system":
        temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
        temp_weight = messages[0].get("loss_weight", 0.0)

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
    for turn_idx, message in enumerate(messages):
        if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
            temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 0.0)
        elif message["role"] == "assistant":
            temp_str += "<|im_start|>" + message["role"] + "\n"
            for val_idx, content in enumerate(message["content"]):
                if content["type"] == "text":
                    temp_str += content["value"]
                elif content["type"] == "reasoning":
                    temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n"  # avoid using special tokens
                elif content["type"] == "tool_call":
                    # Newline between a tool call and a preceding text/tool_call part.
                    if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
                        temp_str += "\n"

                    try:
                        tool_call: ToolCall = json.loads(content["value"])
                    except json.JSONDecodeError:
                        raise ValueError(f"Invalid tool call format: {content['value']}.")

                    temp_str += (
                        '<tool_call>\n{"name": "'
                        + tool_call["name"]
                        + '", "arguments": '
                        + json.dumps(tool_call["arguments"], ensure_ascii=False)
                        + "}\n</tool_call>"
                    )
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "<|im_end|>\n"
            # Assistant turns default to full loss weight.
            temp_weight = message.get("loss_weight", 1.0)
        elif message["role"] == "tool":
            # Consecutive tool results are wrapped inside a single user turn.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                temp_str += "<|im_start|>user"

            temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
            if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
                temp_str += "<|im_end|>\n"

            temp_weight = message.get("loss_weight", 0.0)

        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    if is_generate:
        # Open the assistant turn for generation; never trained on.
        temp_str += "<|im_start|>assistant\n"
        temp_weight = 0.0
        if enable_thinking:
            raise ValueError("The qwen3_nothink template does not support thinking mode.")

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
# Register under "parse_messages" to match RenderingPlugin.parse_messages, which
# looks up self["parse_messages"] (the render function likewise registers the
# plural "render_messages"); the previous singular key was never looked up.
@RenderingPlugin("qwen3_nothink").register("parse_messages")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
    """Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.

    Args:
        generated_text (str): The generated text in the Qwen3 nothink template format.

    Returns:
        Message: The parsed message.

    Raises:
        ValueError: If a ``<tool_call>`` block does not contain valid JSON.
    """
    # Match <thinking>...</thinking> and <tool_call>...</tool_call> blocks
    # (backreference \1 pairs the closing tag), swallowing surrounding whitespace.
    pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    content = []
    last_end = 0
    for match in pattern.finditer(generated_text):
        start, end = match.span()
        if start > last_end:
            # Plain text between tagged blocks.
            text = generated_text[last_end:start].strip()
            if text:
                content.append({"type": "text", "value": text})

        tag_type = match.group(1)
        tag_value = match.group(2).strip()
        if tag_type == "thinking":
            content.append({"type": "reasoning", "value": tag_value.strip()})
        elif tag_type == "tool_call":
            # Validate JSON eagerly so malformed tool calls fail here, not downstream.
            try:
                json.loads(tag_value.strip())
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")

            content.append({"type": "tool_call", "value": tag_value.strip()})

        last_end = end

    if last_end < len(generated_text):
        # Trailing plain text after the last tagged block.
        text = generated_text[last_end:].strip()
        if text:
            content.append({"type": "text", "value": text})

    return Message(role="assistant", content=content)
src/llamafactory/v1/plugins/sampler_plugins/__init__.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/sampler_plugins/vllm.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/trainer_plugins/__init__.py
0 → 100644
View file @
ca625f43
src/llamafactory/v1/plugins/trainer_plugins/batching.py
0 → 100644
View file @
ca625f43
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
...utils.objects
import
StatefulBuffer
from
...utils.plugin
import
BasePlugin
from
...utils.types
import
BatchInfo
,
BatchInput
,
DataLoader
class BatchingPlugin(BasePlugin):
    """Abstract interface for batch-construction strategies; concrete plugins override all methods."""

    def compute_length(self, data_provider: DataLoader) -> int:
        """Compute the length of the batch generator.

        The approximate length is used to calculate the lr schedule.

        Raises:
            NotImplementedError: Must be overridden by a concrete batching plugin.
        """
        raise NotImplementedError()

    def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
        """Fill the buffer with data.

        Raises:
            NotImplementedError: Must be overridden by a concrete batching plugin.
        """
        raise NotImplementedError()

    def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
        """Generate a batch from the buffer.

        Returns ``None`` when the buffer cannot yet yield a full batch (per the
        concrete implementation's contract).

        Raises:
            NotImplementedError: Must be overridden by a concrete batching plugin.
        """
        raise NotImplementedError()
Prev
1
…
9
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment