OpenDAS / LLaMA-Factory · Commits

Commit ca625f43
Authored Mar 30, 2026 by shihm
Commit message: uodata
Parent: 7164651d
Showing 20 changed files with 2149 additions and 376 deletions (+2149 −376).
src/llamafactory/chat/chat_model.py            +31   −5
src/llamafactory/chat/hf_engine.py              +2   −2
src/llamafactory/chat/kt_engine.py            +284   −0
src/llamafactory/chat/vllm_engine.py            +9   −1
src/llamafactory/cli.py                         +6   −135
src/llamafactory/data/collator.py              +13   −4
src/llamafactory/data/converter.py            +144   −3
src/llamafactory/data/data_utils.py            +19   −12
src/llamafactory/data/formatter.py             +43   −26
src/llamafactory/data/loader.py                +19   −18
src/llamafactory/data/mm_plugin.py            +274   −48
src/llamafactory/data/parser.py                +30   −28
src/llamafactory/data/processor/supervised.py   +1   −1
src/llamafactory/data/template.py             +376   −18
src/llamafactory/data/tool_utils.py           +248   −0
src/llamafactory/extras/constants.py          +577   −41
src/llamafactory/extras/env.py                 +31   −21
src/llamafactory/extras/logging.py              +1   −1
src/llamafactory/extras/misc.py                +16   −8
src/llamafactory/extras/packages.py            +25   −4
src/llamafactory/chat/chat_model.py

```diff
@@ -24,9 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional
 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
-from .hf_engine import HuggingfaceEngine
-from .sglang_engine import SGLangEngine
-from .vllm_engine import VllmEngine

 if TYPE_CHECKING:
@@ -49,12 +46,41 @@ class ChatModel:
     def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
         if model_args.infer_backend == EngineName.HF:
+            from .hf_engine import HuggingfaceEngine
+
             self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            try:
+                from .vllm_engine import VllmEngine
+
+                self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "vLLM not installed, you may need to run `pip install vllm`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
         elif model_args.infer_backend == EngineName.SGLANG:
-            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            try:
+                from .sglang_engine import SGLangEngine
+
+                self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "SGLang not installed, you may need to run `pip install sglang[all]`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        elif model_args.infer_backend == EngineName.KT:
+            try:
+                from .kt_engine import KTransformersEngine
+
+                self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "KTransformers not installed, you may need to run `pip install ktransformers`\n"
+                    "or try to use HuggingFace backend: --infer_backend huggingface"
+                ) from e
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
```
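The pattern above is worth calling out: every optional backend import moves inside its own branch, so a missing extra only fails when that backend is actually requested, and the ImportError carries an actionable install hint. A minimal standalone sketch of the same idea (the `load_backend` helper is hypothetical, not part of LLaMA-Factory):

```python
def load_backend(name: str):
    """Lazily import an optional backend; fail with an actionable message."""
    if name == "vllm":
        try:
            import vllm  # only required when this backend is chosen
        except ImportError as e:
            raise ImportError("vLLM is not installed; run `pip install vllm`.") from e
        return vllm
    raise NotImplementedError(f"Unknown backend: {name}")
```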
src/llamafactory/chat/hf_engine.py

```diff
@@ -14,9 +14,9 @@
 import asyncio
 import os
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Callable
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
```
src/llamafactory/chat/kt_engine.py (new file, 0 → 100644)

```python
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import platform
from collections.abc import AsyncGenerator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional

import torch
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import EngineName
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from trl import PreTrainedModelWrapper

    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
from ktransformers.server.config.config import Config
from ktransformers.util.utils import (
    get_compute_capability,
    prefill_and_generate_capture,
)
from ktransformers.util.vendors import GPUVendor, device_manager


logger = logging.get_logger(__name__)


class KTransformersEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.KT
        self.can_generate = finetuning_args.stage == "sft"
        tok_mod = load_tokenizer(model_args)
        self.tokenizer = tok_mod["tokenizer"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )
        self.generating_args = generating_args.to_dict()
        self.max_new_tokens = model_args.kt_maxlen
        self.use_cuda_graph = model_args.kt_use_cuda_graph
        self.mode = model_args.kt_mode
        self.force_think = model_args.kt_force_think
        self.chunk_size = model_args.chunk_size

        try:
            asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: list[str],
        input_kwargs: Optional[dict[str, Any]] = {},
    ) -> list[float]:
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        paired = messages + [{"role": "assistant", "content": ""}]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
        prompt_len = len(prompt_ids)

        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        if "max_new_tokens" in self.generating_args:
            max_tokens = int(self.generating_args["max_new_tokens"])
        elif "max_length" in self.generating_args:
            gl = int(self.generating_args["max_length"])
            max_tokens = gl - prompt_len if gl > prompt_len else 1
        else:
            max_tokens = self.max_new_tokens or 256

        if max_length is not None:
            max_tokens = max(max_length - prompt_len, 1)

        if max_new_tokens is not None:
            max_tokens = int(max_new_tokens)

        max_tokens = max(1, int(max_tokens))

        if self.mode == "long_context":
            max_len_cfg = Config().long_context_config["max_seq_len"]
            need = prompt_len + max_tokens
            assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"

        device = next(self.model.parameters()).device
        input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        if self.force_think:
            think = torch.tensor(
                [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
            )
            input_tensor = torch.cat([input_tensor, think], dim=1)

        use_flashinfer = (
            platform.system() != "Windows"
            and getattr(self.model.config, "architectures", [""])[0]
            in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
            and flashinfer_enabled
            and get_compute_capability() >= 8
            and device_manager.gpu_vendor == GPUVendor.NVIDIA
        )

        def make_gen():
            if use_flashinfer:
                return prefill_and_generate_capture(
                    self.model,
                    self.tokenizer,
                    input_tensor,
                    max_tokens,
                    self.use_cuda_graph,
                    mode=self.mode,
                    force_think=self.force_think,
                    chunk_size=self.chunk_size,
                    use_flashinfer_mla=True,
                    num_heads=self.model.config.num_attention_heads,
                    head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
                    head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
                    + getattr(self.model.config, "qk_nope_head_dim", 0),
                    echo_stream=False,
                )
            else:
                return prefill_and_generate_capture(
                    self.model,
                    self.tokenizer,
                    input_tensor,
                    max_tokens,
                    self.use_cuda_graph,
                    mode=self.mode,
                    force_think=self.force_think,
                    chunk_size=self.chunk_size,
                    echo_stream=False,
                )

        loop = asyncio.get_running_loop()
        q: asyncio.Queue[Optional[str]] = asyncio.Queue()

        def producer():
            try:
                gen = make_gen()
                if hasattr(gen, "__aiter__"):

                    async def drain_async():
                        async for t in gen:
                            loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))

                    asyncio.run(drain_async())
                elif hasattr(gen, "__iter__"):
                    for t in gen:
                        loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
                else:
                    loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
            finally:
                loop.call_soon_threadsafe(q.put_nowait, None)

        Thread(target=producer, daemon=True).start()

        while True:
            item = await q.get()
            if item is None:
                break

            yield item

    @override
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        async with self.semaphore:
            produced = ""
            final_text = ""
            async for t in self._generate(messages, system, tools, **input_kwargs):
                delta = t
                produced = produced + delta
                if delta:
                    final_text += delta

            prompt_ids, _ = self.template.encode_oneturn(
                self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
            )
            return [
                Response(
                    response_text=final_text,
                    response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
                    prompt_length=len(prompt_ids),
                    finish_reason="stop",
                )
            ]

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        async with self.semaphore:
            produced = ""
            async for t in self._generate(messages, system, tools, **input_kwargs):
                delta = t[len(produced):] if t.startswith(produced) else t
                produced = t
                if delta:
                    yield delta

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            return await asyncio.to_thread(self._get_scores, *args)
```
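The `_generate` method above bridges a blocking token generator into async iteration: a daemon thread drains the generator and feeds an `asyncio.Queue` through `call_soon_threadsafe`, with `None` as an end-of-stream sentinel. A self-contained sketch of the same pattern (names here are illustrative, not the project's API):

```python
import asyncio
from collections.abc import AsyncGenerator, Callable, Iterable
from threading import Thread


async def stream(make_gen: Callable[[], Iterable[str]]) -> AsyncGenerator[str, None]:
    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue()

    def producer() -> None:
        try:
            for token in make_gen():  # blocking generator runs off the event loop
                loop.call_soon_threadsafe(q.put_nowait, token)
        finally:
            loop.call_soon_threadsafe(q.put_nowait, None)  # sentinel: end of stream

    Thread(target=producer, daemon=True).start()
    while (item := await q.get()) is not None:
        yield item


async def demo() -> None:
    async for tok in stream(lambda: iter(["hello", " ", "world"])):
        print(tok, end="")


asyncio.run(demo())
```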
src/llamafactory/chat/vllm_engine.py

```diff
@@ -16,6 +16,7 @@ import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union

+from packaging import version
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
@@ -77,11 +78,18 @@ class VllmEngine(BaseEngine):
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,
             "disable_log_stats": True,
-            "disable_log_requests": True,
             "enforce_eager": model_args.vllm_enforce_eager,
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

+        import vllm
+
+        if version.parse(vllm.__version__) <= version.parse("0.10.0"):
+            engine_args["disable_log_requests"] = True
+        else:
+            engine_args["enable_log_requests"] = False
+
         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
             engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
```
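The version gate added here accounts for vLLM renaming its request-logging flag after the 0.10.x line; the boundary value is taken from the diff above. A runnable sketch of the same dispatch (the helper name is hypothetical):

```python
from packaging import version


def request_logging_kwargs(vllm_version: str) -> dict:
    """Pick the right engine-arg name for the installed vLLM release."""
    if version.parse(vllm_version) <= version.parse("0.10.0"):
        return {"disable_log_requests": True}  # old flag name
    return {"enable_log_requests": False}  # renamed in newer releases


assert request_logging_kwargs("0.9.2") == {"disable_log_requests": True}
assert request_logging_kwargs("0.11.0") == {"enable_log_requests": False}
```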
src/llamafactory/cli.py

```diff
@@ -12,145 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-import subprocess
-import sys
-from copy import deepcopy
-from functools import partial
-
-USAGE = (
-    "-" * 70
-    + "\n"
-    + "| Usage:                                                             |\n"
-    + "|   llamafactory-cli api -h: launch an OpenAI-style API server       |\n"
-    + "|   llamafactory-cli chat -h: launch a chat interface in CLI         |\n"
-    + "|   llamafactory-cli eval -h: evaluate models                        |\n"
-    + "|   llamafactory-cli export -h: merge LoRA adapters and export model |\n"
-    + "|   llamafactory-cli train -h: train models                          |\n"
-    + "|   llamafactory-cli webchat -h: launch a chat interface in Web UI   |\n"
-    + "|   llamafactory-cli webui: launch LlamaBoard                        |\n"
-    + "|   llamafactory-cli version: show version info                      |\n"
-    + "-" * 70
-)
-

 def main():
-    from . import launcher
-    from .api.app import run_api
-    from .chat.chat_model import run_chat
-    from .eval.evaluator import run_eval
-    from .extras import logging
-    from .extras.env import VERSION, print_env
-    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-    from .train.tuner import export_model, run_exp
-    from .webui.interface import run_web_demo, run_web_ui
-
-    logger = logging.get_logger(__name__)
-    WELCOME = (
-        "-" * 58
-        + "\n"
-        + f"| Welcome to LLaMA Factory, version {VERSION}"
-        + " " * (21 - len(VERSION))
-        + "|\n|"
-        + " " * 56
-        + "|\n"
-        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
-        + "-" * 58
-    )
-    COMMAND_MAP = {
-        "api": run_api,
-        "chat": run_chat,
-        "env": print_env,
-        "eval": run_eval,
-        "export": export_model,
-        "train": run_exp,
-        "webchat": run_web_demo,
-        "webui": run_web_ui,
-        "version": partial(print, WELCOME),
-        "help": partial(print, USAGE),
-    }
+    from .extras.misc import is_env_enabled

-    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
-        # launch distributed training
-        nnodes = os.getenv("NNODES", "1")
-        node_rank = os.getenv("NODE_RANK", "0")
-        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
-        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
-        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
-        if int(nnodes) > 1:
-            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
-
-        # elastic launch support
-        max_restarts = os.getenv("MAX_RESTARTS", "0")
-        rdzv_id = os.getenv("RDZV_ID")
-        min_nnodes = os.getenv("MIN_NNODES")
-        max_nnodes = os.getenv("MAX_NNODES")
-
-        env = deepcopy(os.environ)
-        if is_env_enabled("OPTIM_TORCH", "1"):
-            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
-            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-        if rdzv_id is not None:
-            # launch elastic job with fault tolerant support when possible
-            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
-            rdzv_nnodes = nnodes
-            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
-            if min_nnodes is not None and max_nnodes is not None:
-                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
-
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
-                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
-                    "--max-restarts {max_restarts} {file_name} {args}"
-                )
-                .format(
-                    rdzv_nnodes=rdzv_nnodes,
-                    nproc_per_node=nproc_per_node,
-                    rdzv_id=rdzv_id,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    max_restarts=max_restarts,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-        else:
-            # NOTE: DO NOT USE shell=True to avoid security risk
-            process = subprocess.run(
-                (
-                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
-                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                )
-                .format(
-                    nnodes=nnodes,
-                    node_rank=node_rank,
-                    nproc_per_node=nproc_per_node,
-                    master_addr=master_addr,
-                    master_port=master_port,
-                    file_name=launcher.__file__,
-                    args=" ".join(sys.argv[1:]),
-                )
-                .split(),
-                env=env,
-                check=True,
-            )
-
-        sys.exit(process.returncode)
-    elif command in COMMAND_MAP:
-        COMMAND_MAP[command]()
+    if is_env_enabled("USE_V1"):
+        from .v1 import launcher
     else:
-        print(f"Unknown command: {command}.\n{USAGE}")
+        from . import launcher
+
+    launcher.launch()


 if __name__ == "__main__":
```
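After this change the CLI is a thin dispatcher: the command map and torchrun bootstrapping move out of `cli.py`, and `main()` only selects a launcher module based on the `USE_V1` flag. A simplified stand-in for the gating (this is not the project's actual `is_env_enabled`, whose accepted values may differ):

```python
import os


def is_env_enabled(name: str, default: str = "0") -> bool:
    # simplified stand-in for llamafactory's extras.misc.is_env_enabled
    return os.getenv(name, default).strip().lower() in ("1", "true", "y")


# USE_V1=1 llamafactory-cli train ...  -> v1 launcher; otherwise the legacy one
launcher_module = ".v1.launcher" if is_env_enabled("USE_V1") else ".launcher"
print(launcher_module)
```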
src/llamafactory/data/collator.py

```diff
@@ -194,7 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

-            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2.5 omni
+            if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
                 rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                 if feature_attention_mask is not None:  # FIXME: need to get video image lengths
@@ -205,16 +205,25 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
-            else:  # for qwen2 vl
+            else:  # for qwen vl
                 features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)

         if (
             self.model is not None
-            and getattr(self.model.config, "model_type", None)
-            in ["glm4v", "qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            and getattr(self.model.config, "model_type", None)
+            in [
+                "glm4v",
+                "Keye",
+                "qwen2_vl",
+                "qwen2_5_vl",
+                "qwen2_5_omni_thinker",
+                "qwen3_omni_moe_thinker",
+                "qwen3_vl",
+                "qwen3_vl_moe",
+            ]
             and ("position_ids" not in features or features["position_ids"].dim() != 3)
         ):
-            raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
+            raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")

         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
```
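For context on the `dim() != 3` check above: mrope-style models expect position ids with a leading channel axis (temporal, height, width), i.e. shape `(3, batch_size, seq_len)`, and the collator rejects plain 2-D position ids for the listed model types. A minimal sketch assuming that layout:

```python
import torch

# mrope position ids carry three channels (temporal, height, width),
# giving shape (3, batch_size, seq_len); a 2-D (batch_size, seq_len)
# tensor would fail the collator's dim() != 3 check.
batch_size, seq_len = 2, 16
position_ids = torch.arange(seq_len).expand(3, batch_size, seq_len)
assert position_ids.dim() == 3
```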
src/llamafactory/data/converter.py

```diff
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import json
 import os
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 from ..extras import logging
 from .data_utils import Role
@@ -40,7 +40,7 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
+    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
         r"""Optionally concatenate media path to media dir when loading from local disk."""
         if medias is None:
             return None
@@ -227,9 +227,150 @@ class SharegptDatasetConverter(DatasetConverter):
         return output


+@dataclass
+class OpenAIDatasetConverter(DatasetConverter):
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+        tag_mapping = {
+            self.dataset_attr.user_tag: Role.USER.value,
+            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
+            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
+            self.dataset_attr.function_tag: Role.FUNCTION.value,
+            self.dataset_attr.system_tag: Role.SYSTEM.value,
+        }
+        messages = example[self.dataset_attr.messages]
+        if (
+            self.dataset_attr.system_tag
+            and len(messages) != 0
+            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
+        ):
+            system = messages[0][self.dataset_attr.content_tag]
+            messages = messages[1:]
+        else:
+            system = example.get(self.dataset_attr.system, "") if self.dataset_attr.system else ""
+
+        aligned_messages = []
+        tool_responses = []
+        broken_data = False
+        for turn_idx, message in enumerate(messages):
+            role = message[self.dataset_attr.role_tag]
+            content = message[self.dataset_attr.content_tag]
+            if role in [self.dataset_attr.assistant_tag, self.dataset_attr.function_tag]:
+                if "tool_calls" in message and len(message["tool_calls"]) > 0:
+                    tool_calls_list = [tool["function"] for tool in message["tool_calls"]]
+                    content = json.dumps(tool_calls_list, ensure_ascii=False)
+                    role = self.dataset_attr.function_tag
+
+            if role == self.dataset_attr.observation_tag:
+                tool_responses.append(content)
+                continue
+            elif len(tool_responses) > 0:
+                _content = "\n</tool_response>\n<tool_response>\n".join(tool_responses)
+                aligned_messages.append(
+                    {
+                        "role": Role.OBSERVATION.value,
+                        "content": _content,
+                    }
+                )
+                tool_responses = []
+
+            aligned_messages.append(
+                {
+                    "role": tag_mapping[role],
+                    "content": content,
+                }
+            )
+
+        odd_tags = (Role.USER.value, Role.OBSERVATION.value)
+        even_tags = (Role.ASSISTANT.value, Role.FUNCTION.value)
+        accept_tags = (odd_tags, even_tags)
+        for turn_idx, message in enumerate(aligned_messages):
+            if message["role"] not in accept_tags[turn_idx % 2]:
+                logger.warning_rank0(f"Invalid role tag in {messages}.")
+                broken_data = True
+                break
+
+        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
+            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
+        ):
+            logger.warning_rank0(f"Invalid message count in {messages}.")
+            broken_data = True
+
+        if broken_data:
+            logger.warning_rank0("Skipping this abnormal example.")
+            prompt, response = [], []
+        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):
+            # kto example
+            prompt = aligned_messages[:-1]
+            response = aligned_messages[-1:]
+            if example[self.dataset_attr.kto_tag]:
+                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
+            else:
+                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
+        elif (
+            self.dataset_attr.ranking
+            and isinstance(example[self.dataset_attr.chosen], dict)
+            and isinstance(example[self.dataset_attr.rejected], dict)
+        ):
+            # pairwise example
+            chosen = example[self.dataset_attr.chosen]
+            rejected = example[self.dataset_attr.rejected]
+            if (
+                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
+                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
+            ):
+                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
+                broken_data = True
+
+            prompt = aligned_messages
+            response = [
+                {
+                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
+                    "content": chosen[self.dataset_attr.content_tag],
+                },
+                {
+                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
+                    "content": rejected[self.dataset_attr.content_tag],
+                },
+            ]
+        else:
+            # normal example
+            prompt = aligned_messages[:-1]
+            response = aligned_messages[-1:]
+
+        tools = example.get(self.dataset_attr.tools, "") if self.dataset_attr.tools else ""
+        if isinstance(tools, dict) or isinstance(tools, list):
+            tools = json.dumps(tools, ensure_ascii=False)
+
+        short_system_prompt = "detailed thinking off"
+        if not system:
+            if not tools:
+                system = short_system_prompt
+            else:
+                pass
+        else:
+            if not tools:
+                if "detailed thinking on" in system or "detailed thinking off" in system:
+                    pass
+                else:
+                    system += "\n" + short_system_prompt
+            else:
+                system += "\n"
+
+        output = {
+            "_prompt": prompt,
+            "_response": response,
+            "_system": system,
+            "_tools": tools,
+            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
+            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
+            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+        }
+        return output
+
+
 DATASET_CONVERTERS = {
     "alpaca": AlpacaDatasetConverter,
     "sharegpt": SharegptDatasetConverter,
+    "openai": OpenAIDatasetConverter,
 }
```
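The new `OpenAIDatasetConverter` consumes records in the OpenAI chat format, folding `tool_calls` into a function turn and merging consecutive tool responses into one observation turn. An illustrative record it should accept, assuming a typical tag configuration (`role`/`content` keys, a tool role as the observation tag); the field values are invented:

```python
example = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "content": "",
            # each tool call's "function" dict is JSON-dumped into the turn content
            "tool_calls": [{"function": {"name": "get_weather", "arguments": {"city": "Paris"}}}],
        },
        {"role": "tool", "content": '{"temp_c": 21}'},  # becomes an observation turn
        {"role": "assistant", "content": "It is 21 °C in Paris."},
    ]
}
```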
src/llamafactory/data/data_utils.py

```diff
@@ -81,41 +81,48 @@ def split_dataset(
     eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
-) -> "DatasetDict":
-    r"""Split the dataset and returns a dataset dict containing train set and validation set.
+) -> tuple[dict, dict]:
+    r"""Split the dataset and returns two dicts containing train set and validation set.

     Support both map dataset and iterable dataset.
+
+    Returns:
+        train_dict: Dictionary containing training data with key "train"
+        eval_dict: Dictionary containing evaluation data with keys "validation" or "validation_{name}"
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

-    dataset_dict = {}
+    # keep train and eval in separate dicts and return them separately for clearer handling outside
+    train_dict, eval_dict = {}, {}
     if dataset is not None:
         if data_args.streaming:
             dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

         if data_args.val_size > 1e-6:
             if data_args.streaming:
-                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
-                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+                eval_dict["validation"] = dataset.take(int(data_args.val_size))
+                train_dict["train"] = dataset.skip(int(data_args.val_size))
             else:
                 val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+                split_result = dataset.train_test_split(test_size=val_size, seed=seed)
+                train_dict["train"] = split_result["train"]
+                eval_dict["validation"] = split_result["test"]
         else:
-            dataset_dict["train"] = dataset
+            train_dict["train"] = dataset

     if eval_dataset is not None:
         if isinstance(eval_dataset, dict):
-            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            for name, data in eval_dataset.items():
+                eval_dict[f"validation_{name}"] = data
         else:
             if data_args.streaming:
                 eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

-            dataset_dict["validation"] = eval_dataset
+            eval_dict["validation"] = eval_dataset

-    return DatasetDict(dataset_dict)
+    return train_dict, eval_dict


 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
```
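`split_dataset` now returns two plain dicts instead of a single `DatasetDict`, leaving recombination to the caller (see the `loader.py` hunk below). A runnable sketch of the new contract using toy data:

```python
from datasets import Dataset, DatasetDict

# stand-in for the map-dataset branch of split_dataset
dataset = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
split = dataset.train_test_split(test_size=0.2, seed=42)
train_dict = {"train": split["train"]}      # first return value
eval_dict = {"validation": split["test"]}   # second return value

# how the caller recombines after per-side preprocessing
dataset_dict = DatasetDict({**train_dict, **eval_dict})
print(dataset_dict)
```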
src/llamafactory/data/formatter.py

```diff
@@ -16,7 +16,6 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional, Union

 from typing_extensions import override
@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[str] = None
+    tool_format: str | None = None

     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
         r"""Forms a list of slots according to the inputs to encode."""
         ...

-    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+    def extract(self, content: str) -> str | list["FunctionCall"]:
         r"""Extract a list of tuples from the response message if using tools.

         Each tuple consists of function name and function arguments.
@@ -97,28 +96,46 @@ class FunctionFormatter(StringFormatter):
     @override
     def apply(self, **kwargs) -> SLOTS:
         content: str = kwargs.pop("content")
-        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
-        thought = re.search(regex, content)
-        if thought:
-            content = content.replace(thought.group(0), "")
-
-        functions: list[FunctionCall] = []
-        try:
-            tool_calls = json.loads(content)
-            if not isinstance(tool_calls, list):  # parallel function call
-                tool_calls = [tool_calls]
-
-            for tool_call in tool_calls:
-                functions.append(
-                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
-                )
-        except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string
-
-        function_str = self.tool_utils.function_formatter(functions)
-        if thought:
-            function_str = thought.group(0) + function_str
+        thought_words = kwargs.pop("thought_words", None)
+        tool_call_words = kwargs.pop("tool_call_words", None)
+
+        def _parse_functions(json_content: str) -> list["FunctionCall"]:
+            try:
+                tool_calls = json.loads(json_content)
+                if not isinstance(tool_calls, list):  # parallel function call
+                    tool_calls = [tool_calls]
+
+                return [
+                    FunctionCall(tc["name"], json.dumps(tc["arguments"], ensure_ascii=False)) for tc in tool_calls
+                ]
+            except json.JSONDecodeError:
+                raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")
+
+        tool_call_match = None
+        if tool_call_words and len(tool_call_words) == 2:
+            tool_call_regex = re.compile(
+                rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL
+            )
+            tool_call_match = re.search(tool_call_regex, content)
+
+        if tool_call_match is None:
+            thought_match = None
+            if thought_words and len(thought_words) == 2:
+                regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
+                thought_match = re.search(regex, content)
+
+            if thought_match:
+                json_part = content.replace(thought_match.group(0), "")
+            else:
+                json_part = content
+
+            functions = _parse_functions(json_part)
+            function_str = self.tool_utils.function_formatter(functions)
+            if thought_match:
+                function_str = thought_match.group(0) + function_str
+        else:
+            thought_content = content.replace(tool_call_match.group(0), "")
+            functions = _parse_functions(tool_call_match.group(1))
+            function_str = self.tool_utils.function_formatter(functions)
+            function_str = thought_content + function_str

         return super().apply(content=function_str)
@@ -138,5 +155,5 @@ class ToolFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
-    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+    def extract(self, content: str) -> str | list["FunctionCall"]:
         return self.tool_utils.tool_extractor(content)
```
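The reworked `FunctionFormatter.apply` locates the tool-call payload between configurable wrapper words before JSON-decoding it, instead of assuming the whole message is JSON. A standalone sketch of that extraction step (the wrapper strings here are examples, not fixed project constants):

```python
import json
import re

content = '<think>plan the call</think><tool_call>[{"name": "f", "arguments": {}}]</tool_call>'
tool_call_words = ("<tool_call>", "</tool_call>")

# escape the wrappers and capture everything between them, non-greedily
pattern = re.compile(rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL)
match = pattern.search(content)
assert match is not None
tool_calls = json.loads(match.group(1))        # [{'name': 'f', 'arguments': {}}]
thought = content.replace(match.group(0), "")  # '<think>plan the call</think>'
print(tool_calls, thought)
```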
src/llamafactory/data/loader.py

```diff
@@ -16,7 +16,7 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
-from datasets import Dataset, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
@@ -137,7 +137,6 @@ def _load_single_dataset(
         cache_dir=model_args.cache_dir,
         token=model_args.hf_hub_token,
         num_proc=data_args.preprocessing_num_workers,
-        trust_remote_code=model_args.trust_remote_code,
         streaming=data_args.streaming and dataset_attr.load_from != "file",
     )
     if data_args.streaming and dataset_attr.load_from == "file":
@@ -163,13 +162,13 @@ def _load_single_dataset(
 def _get_merged_dataset(
-    dataset_names: Optional[list[str]],
+    dataset_names: list[str] | None,
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     return_dict: bool = False,
-) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
         return None
@@ -228,7 +227,7 @@ def _get_dataset_processor(
 def _get_preprocessed_dataset(
-    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    dataset: Union["Dataset", "IterableDataset"] | None,
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
@@ -236,7 +235,7 @@ def _get_preprocessed_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+) -> Union["Dataset", "IterableDataset"] | None:
     r"""Preprocesses the dataset, including format checking and tokenization."""
     if dataset is None:
         return None
@@ -312,20 +311,22 @@ def get_dataset(
         )

     with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
-        dataset = _get_preprocessed_dataset(
-            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
-        )
-        if isinstance(eval_dataset, dict):
-            for eval_name, eval_data in eval_dataset.items():
-                eval_dataset[eval_name] = _get_preprocessed_dataset(
-                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-                )
-        else:
-            eval_dataset = _get_preprocessed_dataset(
-                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-            )
-
-        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        # split first, so that eval_dataset (whether provided or split off) is preprocessed appropriately
+        train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        if "train" in train_dict:
+            train_dict["train"] = _get_preprocessed_dataset(
+                train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+            )
+
+        for key in eval_dict:
+            eval_dict[key] = _get_preprocessed_dataset(
+                eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
+
+        # combine train and eval dictionaries
+        dataset_dict = DatasetDict({**train_dict, **eval_dict})

         if data_args.tokenized_path is not None:  # save tokenized dataset to disk
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
```
src/llamafactory/data/mm_plugin.py
View file @
ca625f43
...
...
@@ -22,10 +22,11 @@ import re
from
copy
import
deepcopy
from
dataclasses
import
dataclass
from
io
import
BytesIO
from
typing
import
TYPE_CHECKING
,
BinaryIO
,
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
TYPE_CHECKING
,
BinaryIO
,
Literal
,
NotRequired
,
Optional
,
TypedDict
,
Union
import
numpy
as
np
import
torch
import
torchaudio
from
transformers.image_utils
import
get_image_size
,
is_valid_image
,
to_numpy_array
from
transformers.models.mllama.processing_mllama
import
(
convert_sparse_cross_attention_mask_to_dense
,
...
...
@@ -34,16 +35,7 @@ from transformers.models.mllama.processing_mllama import (
from
typing_extensions
import
override
from
..extras.constants
import
AUDIO_PLACEHOLDER
,
IGNORE_INDEX
,
IMAGE_PLACEHOLDER
,
VIDEO_PLACEHOLDER
from
..extras.packages
import
(
is_librosa_available
,
is_pillow_available
,
is_pyav_available
,
is_transformers_version_greater_than
,
)
if
is_librosa_available
():
import
librosa
from
..extras.packages
import
is_pillow_available
,
is_pyav_available
,
is_transformers_version_greater_than
if
is_pillow_available
():
...
...
@@ -68,15 +60,28 @@ if TYPE_CHECKING:
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
transformers.feature_extraction_sequence_utils
import
SequenceFeatureExtractor
from
transformers.image_processing_utils
import
BaseImageProcessor
from
transformers.video_processing_utils
import
BaseVideoProcessor
class
EncodedImage
(
TypedDict
):
path
:
Optional
[
str
]
bytes
:
Optional
[
bytes
]
path
:
str
|
None
bytes
:
bytes
|
None
ImageInput
=
Union
[
str
,
bytes
,
EncodedImage
,
BinaryIO
,
ImageObject
]
VideoInput
=
Union
[
str
,
BinaryIO
,
list
[
list
[
ImageInput
]]]
AudioInput
=
Union
[
str
,
BinaryIO
,
NDArray
]
class
RegularizedImageOutput
(
TypedDict
):
images
:
list
[
ImageObject
]
class
RegularizedVideoOutput
(
TypedDict
):
videos
:
list
[
list
[
ImageObject
]]
durations
:
list
[
float
]
fps_per_video
:
NotRequired
[
list
[
float
]]
class
RegularizedAudioOutput
(
TypedDict
):
audios
:
list
[
NDArray
]
sampling_rates
:
list
[
float
]
class
MMProcessor
(
ProcessorMixin
):
patch_size
:
int
image_seq_length
:
int
...
...
@@ -134,14 +139,14 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
def
_check_video_is_nested_images
(
video
:
"VideoInput"
)
->
bool
:
r
"""Check if the video is nested images."""
return
isinstance
(
video
,
list
)
and
all
(
isinstance
(
frame
,
(
str
,
BinaryIO
,
dict
))
for
frame
in
video
)
return
isinstance
(
video
,
list
)
and
all
(
isinstance
(
frame
,
(
str
,
BinaryIO
,
dict
,
ImageObject
))
for
frame
in
video
)
@
dataclass
class
MMPluginMixin
:
image_token
:
Optional
[
str
]
video_token
:
Optional
[
str
]
audio_token
:
Optional
[
str
]
image_token
:
str
|
None
video_token
:
str
|
None
audio_token
:
str
|
None
expand_mm_tokens
:
bool
=
True
def
_validate_input
(
...
...
@@ -244,7 +249,7 @@ class MMPluginMixin:
sample_frames
=
min
(
total_frames
,
video_maxlen
,
sample_frames
)
return
np
.
linspace
(
0
,
total_frames
-
1
,
sample_frames
).
astype
(
np
.
int32
)
def
_regularize_images
(
self
,
images
:
list
[
"ImageInput"
],
**
kwargs
)
->
dict
[
str
,
list
[
"ImageObject"
]]
:
def
_regularize_images
(
self
,
images
:
list
[
"ImageInput"
],
**
kwargs
)
->
"RegularizedImageOutput"
:
r
"""Regularize images to avoid error. Including reading and pre-processing."""
results
=
[]
for
image
in
images
:
...
...
@@ -265,9 +270,10 @@ class MMPluginMixin:
return
{
"images"
:
results
}
def
_regularize_videos
(
self
,
videos
:
list
[
"VideoInput"
],
**
kwargs
)
->
dict
[
str
,
list
[
list
[
"ImageObject"
]]]
:
def
_regularize_videos
(
self
,
videos
:
list
[
"VideoInput"
],
**
kwargs
)
->
"RegularizedVideoOutput"
:
r
"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results
=
[]
durations
=
[]
for
video
in
videos
:
frames
:
list
[
ImageObject
]
=
[]
if
_check_video_is_nested_images
(
video
):
...
...
@@ -275,6 +281,7 @@ class MMPluginMixin:
if
not
is_valid_image
(
frame
)
and
not
isinstance
(
frame
,
dict
)
and
not
os
.
path
.
exists
(
frame
):
raise
ValueError
(
"Invalid image found in video frames."
)
frames
=
video
durations
.
append
(
len
(
frames
)
/
kwargs
.
get
(
"video_fps"
,
2.0
))
else
:
container
=
av
.
open
(
video
,
"r"
)
video_stream
=
next
(
stream
for
stream
in
container
.
streams
if
stream
.
type
==
"video"
)
...
...
@@ -284,19 +291,31 @@ class MMPluginMixin:
if
frame_idx
in
sample_indices
:
frames
.
append
(
frame
.
to_image
())
if
video_stream
.
duration
is
None
:
durations
.
append
(
len
(
frames
)
/
kwargs
.
get
(
"video_fps"
,
2.0
))
else
:
durations
.
append
(
float
(
video_stream
.
duration
*
video_stream
.
time_base
))
frames
=
self
.
_regularize_images
(
frames
,
**
kwargs
)[
"images"
]
results
.
append
(
frames
)
return
{
"videos"
:
results
}
return
{
"videos"
:
results
,
"durations"
:
durations
}
def
_regularize_audios
(
self
,
audios
:
list
[
"AudioInput"
],
sampling_rate
:
float
,
**
kwargs
)
->
dict
[
str
,
Union
[
list
[
"NDArray"
],
list
[
float
]]]
:
)
->
"RegularizedAudioOutput"
:
r
"""Regularizes audios to avoid error. Including reading and resampling."""
results
,
sampling_rates
=
[],
[]
for
audio
in
audios
:
if
not
isinstance
(
audio
,
np
.
ndarray
):
audio
,
sampling_rate
=
librosa
.
load
(
audio
,
sr
=
sampling_rate
)
audio
,
sr
=
torchaudio
.
load
(
audio
)
if
audio
.
shape
[
0
]
>
1
:
audio
=
audio
.
mean
(
dim
=
0
,
keepdim
=
True
)
if
sr
!=
sampling_rate
:
audio
=
torchaudio
.
functional
.
resample
(
audio
,
sr
,
sampling_rate
)
audio
=
audio
.
squeeze
(
0
).
numpy
()
results
.
append
(
audio
)
sampling_rates
.
append
(
sampling_rate
)
...
...
@@ -309,7 +328,7 @@ class MMPluginMixin:
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
"MMProcessor"
,
imglens
:
Optional
[
list
[
int
]
]
=
None
,
imglens
:
list
[
int
]
|
None
=
None
,
)
->
dict
[
str
,
"torch.Tensor"
]:
r
"""Process visual inputs.
...
...
@@ -407,13 +426,13 @@ class BasePlugin(MMPluginMixin):
def
process_token_ids
(
self
,
input_ids
:
list
[
int
],
labels
:
Optional
[
list
[
int
]
]
,
labels
:
list
[
int
]
|
None
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"MMProcessor"
],
)
->
tuple
[
list
[
int
],
Optional
[
list
[
int
]
]
]:
)
->
tuple
[
list
[
int
],
list
[
int
]
|
None
]:
r
"""Pre-process token ids after tokenization for VLMs."""
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
return
input_ids
,
labels
...
...
@@ -446,6 +465,57 @@ class BasePlugin(MMPluginMixin):
return
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
@
dataclass
class
ErnieVLPlugin
(
BasePlugin
):
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
messages
=
deepcopy
(
messages
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
)
merge_length
:
int
=
getattr
(
image_processor
,
"merge_size"
)
**
2
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
else
:
image_grid_thw
=
[
None
]
*
len
(
images
)
video_grid_thw
=
[
None
]
*
len
(
videos
)
image_idx
,
video_idx
=
0
,
0
for
message
in
messages
:
content
=
message
[
"content"
]
image_token
=
self
.
image_token
or
"<|IMAGE_PLACEHOLDER|>"
video_token
=
self
.
video_token
or
"<|VIDEO_PLACEHOLDER|>"
while
IMAGE_PLACEHOLDER
in
content
:
image_seqlen
=
image_grid_thw
[
image_idx
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"Picture
{
image_idx
+
1
}
:<|IMAGE_START|>
{
image_token
*
image_seqlen
}
<|IMAGE_END|>"
,
1
,
)
image_idx
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
video_seqlen
=
video_grid_thw
[
video_idx
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
f
"Video
{
video_idx
+
1
}
:<|VIDEO_START|>
{
video_token
*
video_seqlen
}
<|VIDEO_END|>"
,
1
,
)
video_idx
+=
1
message
[
"content"
]
=
content
return
messages
@
dataclass
class
Gemma3Plugin
(
BasePlugin
):
@
override
...
...
@@ -1235,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
def
process_token_ids
(
self
,
input_ids
:
list
[
int
],
labels
:
Optional
[
list
[
int
]
]
,
labels
:
list
[
int
]
|
None
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"MMProcessor"
],
)
->
tuple
[
list
[
int
],
Optional
[
list
[
int
]
]
]:
)
->
tuple
[
list
[
int
],
list
[
int
]
|
None
]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
num_images
=
len
(
images
)
image_seqlen
=
processor
.
image_seq_length
if
self
.
expand_mm_tokens
else
0
# skip mm token
...
...
@@ -1397,6 +1467,9 @@ class Qwen2AudioPlugin(BasePlugin):
@
dataclass
class
Qwen2VLPlugin
(
BasePlugin
):
vision_bos_token
:
str
=
"<|vision_start|>"
vision_eos_token
:
str
=
"<|vision_end|>"
@
override
def
_preprocess_image
(
self
,
image
:
"ImageObject"
,
**
kwargs
)
->
"ImageObject"
:
image
=
super
().
_preprocess_image
(
image
,
**
kwargs
)
...
...
@@ -1415,10 +1488,8 @@ class Qwen2VLPlugin(BasePlugin):
return
image
@
override
def
_regularize_videos
(
self
,
videos
:
list
[
"VideoInput"
],
**
kwargs
)
->
dict
[
str
,
Union
[
list
[
list
[
"ImageObject"
]],
list
[
float
]]]:
results
,
fps_per_video
=
[],
[]
def
_regularize_videos
(
self
,
videos
:
list
[
"VideoInput"
],
**
kwargs
)
->
"RegularizedVideoOutput"
:
results
,
fps_per_video
,
durations
=
[],
[],
[]
for
video
in
videos
:
frames
:
list
[
ImageObject
]
=
[]
if
_check_video_is_nested_images
(
video
):
...
...
@@ -1428,6 +1499,7 @@ class Qwen2VLPlugin(BasePlugin):
frames
=
video
fps_per_video
.
append
(
kwargs
.
get
(
"video_fps"
,
2.0
))
durations
.
append
(
len
(
frames
)
/
kwargs
.
get
(
"video_fps"
,
2.0
))
else
:
container
=
av
.
open
(
video
,
"r"
)
video_stream
=
next
(
stream
for
stream
in
container
.
streams
if
stream
.
type
==
"video"
)
...
...
@@ -1439,8 +1511,10 @@ class Qwen2VLPlugin(BasePlugin):
if
video_stream
.
duration
is
None
:
fps_per_video
.
append
(
kwargs
.
get
(
"video_fps"
,
2.0
))
durations
.
append
(
len
(
frames
)
/
kwargs
.
get
(
"video_fps"
,
2.0
))
else
:
fps_per_video
.
append
(
len
(
sample_indices
)
/
float
(
video_stream
.
duration
*
video_stream
.
time_base
))
durations
.
append
(
float
(
video_stream
.
duration
*
video_stream
.
time_base
))
if
len
(
frames
)
%
2
!=
0
:
frames
.
append
(
frames
[
-
1
])
...
...
@@ -1448,7 +1522,7 @@ class Qwen2VLPlugin(BasePlugin):
frames
=
self
.
_regularize_images
(
frames
,
**
kwargs
)[
"images"
]
results
.
append
(
frames
)
return
{
"videos"
:
results
,
"fps_per_video"
:
fps_per_video
}
return
{
"videos"
:
results
,
"fps_per_video"
:
fps_per_video
,
"durations"
:
durations
}
@
override
def
_get_mm_inputs
(
...
...
@@ -1459,6 +1533,7 @@ class Qwen2VLPlugin(BasePlugin):
processor
:
"MMProcessor"
,
)
->
dict
[
str
,
"torch.Tensor"
]:
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
video_processor
:
BaseVideoProcessor
=
getattr
(
processor
,
"video_processor"
,
None
)
mm_inputs
=
{}
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
...
...
@@ -1476,7 +1551,7 @@ class Qwen2VLPlugin(BasePlugin):
video_fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
128
),
)
mm_inputs
.
update
(
image
_processor
(
images
=
None
,
videos
=
video_data
[
"videos"
],
return_tensors
=
"pt"
))
mm_inputs
.
update
(
video
_processor
(
videos
=
video_data
[
"videos"
],
return_tensors
=
"pt"
))
temporal_patch_size
:
int
=
getattr
(
image_processor
,
"temporal_patch_size"
,
2
)
if
"second_per_grid_ts"
in
processor
.
model_input_names
:
mm_inputs
[
"second_per_grid_ts"
]
=
[
temporal_patch_size
/
fps
for
fps
in
video_data
[
"fps_per_video"
]]
...
...
@@ -1512,15 +1587,142 @@ class Qwen2VLPlugin(BasePlugin):
while
IMAGE_PLACEHOLDER
in
content
:
image_seqlen
=
image_grid_thw
[
num_image_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"<|vision_start|>
{
self
.
image_token
*
image_seqlen
}
<|vision_end|>"
,
1
IMAGE_PLACEHOLDER
,
f
"
{
self
.
vision_bos_token
}{
self
.
image_token
*
image_seqlen
}{
self
.
vision_eos_token
}
"
,
1
,
)
num_image_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
video_seqlen
=
video_grid_thw
[
num_video_tokens
].
prod
()
//
merge_length
if
self
.
expand_mm_tokens
else
1
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
f
"<|vision_start|>
{
self
.
video_token
*
video_seqlen
}
<|vision_end|>"
,
1
VIDEO_PLACEHOLDER
,
f
"
{
self
.
vision_bos_token
}{
self
.
video_token
*
video_seqlen
}{
self
.
vision_eos_token
}
"
,
1
,
)
num_video_tokens
+=
1
message
[
"content"
]
=
content
return
messages
@
dataclass
class
Qwen3VLPlugin
(
Qwen2VLPlugin
):
@
override
def
_get_mm_inputs
(
self
,
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
"MMProcessor"
,
)
->
dict
[
str
,
"torch.Tensor"
]:
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
,
None
)
video_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"video_processor"
,
None
)
mm_inputs
=
{}
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
images
,
image_max_pixels
=
getattr
(
processor
,
"image_max_pixels"
,
768
*
768
),
image_min_pixels
=
getattr
(
processor
,
"image_min_pixels"
,
32
*
32
),
)[
"images"
]
mm_inputs
.
update
(
image_processor
(
images
,
return_tensors
=
"pt"
))
if
len
(
videos
)
!=
0
:
videos
=
self
.
_regularize_videos
(
videos
,
image_max_pixels
=
getattr
(
processor
,
"video_max_pixels"
,
256
*
256
),
image_min_pixels
=
getattr
(
processor
,
"video_min_pixels"
,
16
*
16
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
128
),
)
video_metadata
=
[
{
"fps"
:
getattr
(
processor
,
"video_fps"
,
24.0
),
"duration"
:
duration
,
"total_num_frames"
:
len
(
video
)}
for
video
,
duration
in
zip
(
videos
[
"videos"
],
videos
[
"durations"
])
]
mm_inputs
.
update
(
video_processor
(
videos
=
videos
[
"videos"
],
video_metadata
=
video_metadata
,
fps
=
getattr
(
processor
,
"video_fps"
,
2.0
),
return_metadata
=
True
,
)
)
temporal_patch_size
:
int
=
getattr
(
image_processor
,
"temporal_patch_size"
,
2
)
if
"second_per_grid_ts"
in
processor
.
model_input_names
:
mm_inputs
[
"second_per_grid_ts"
]
=
[
temporal_patch_size
/
fps
for
fps
in
videos
[
"fps_per_video"
]]
return
mm_inputs
@
override
def
process_messages
(
self
,
messages
:
list
[
dict
[
str
,
str
]],
images
:
list
[
"ImageInput"
],
videos
:
list
[
"VideoInput"
],
audios
:
list
[
"AudioInput"
],
processor
:
Optional
[
"MMProcessor"
],
)
->
list
[
dict
[
str
,
str
]]:
self
.
_validate_input
(
processor
,
images
,
videos
,
audios
)
self
.
_validate_messages
(
messages
,
images
,
videos
,
audios
)
num_image_tokens
,
num_video_tokens
=
0
,
0
messages
=
deepcopy
(
messages
)
image_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"image_processor"
)
video_processor
:
BaseImageProcessor
=
getattr
(
processor
,
"video_processor"
)
image_merge_length
:
int
=
getattr
(
image_processor
,
"merge_size"
)
**
2
video_merge_length
:
int
=
getattr
(
video_processor
,
"merge_size"
)
**
2
if
self
.
expand_mm_tokens
:
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
audios
,
processor
)
image_grid_thw
=
mm_inputs
.
get
(
"image_grid_thw"
,
[])
video_grid_thw
=
mm_inputs
.
get
(
"video_grid_thw"
,
[])
num_frames
=
video_grid_thw
[
0
][
0
]
if
len
(
video_grid_thw
)
>
0
else
0
# hard code for now
video_metadata
=
mm_inputs
.
get
(
"video_metadata"
,
{})
else
:
image_grid_thw
=
[
None
]
*
len
(
images
)
video_grid_thw
=
[
None
]
*
len
(
videos
)
num_frames
=
0
timestamps
=
[
0
]
for
idx
,
message
in
enumerate
(
messages
):
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
image_seqlen
=
(
image_grid_thw
[
num_image_tokens
].
prod
()
//
image_merge_length
if
self
.
expand_mm_tokens
else
1
)
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
f
"
{
self
.
vision_bos_token
}{
self
.
image_token
*
image_seqlen
}{
self
.
vision_eos_token
}
"
,
1
,
)
num_image_tokens
+=
1
while
VIDEO_PLACEHOLDER
in
content
:
if
self
.
expand_mm_tokens
:
metadata
=
video_metadata
[
idx
]
timestamps
=
processor
.
_calculate_timestamps
(
metadata
.
frames_indices
,
metadata
.
fps
,
video_processor
.
merge_size
,
)
video_structure
=
""
for
frame_index
in
range
(
num_frames
):
video_seqlen
=
(
video_grid_thw
[
num_video_tokens
][
1
:].
prod
()
//
video_merge_length
if
self
.
expand_mm_tokens
else
1
)
timestamp_sec
=
timestamps
[
frame_index
]
frame_structure
=
(
f
"<
{
timestamp_sec
:.
1
f
}
seconds>"
f
"
{
self
.
vision_bos_token
}{
self
.
video_token
*
video_seqlen
}{
self
.
vision_eos_token
}
"
)
video_structure
+=
frame_structure
else
:
video_structure
=
f
"
{
self
.
vision_bos_token
}{
self
.
video_token
}{
self
.
vision_eos_token
}
"
content
=
content
.
replace
(
VIDEO_PLACEHOLDER
,
video_structure
,
1
)
num_video_tokens
+=
1
message
[
"content"
]
=
content
...
...
@@ -1559,7 +1761,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
)
# prepare video metadata
video_metadata
=
[
{
"fps"
:
2
,
"duration"
:
len
(
video
),
"total_frames"
:
len
(
video
)}
for
video
in
video_data
[
"videos"
]
{
"fps"
:
2
,
"duration"
:
duration
,
"total_frames"
:
len
(
video
)}
for
video
,
duration
in
zip
(
video_data
[
"videos"
],
video_data
[
"durations"
])
]
mm_inputs
.
update
(
video_processor
(
images
=
None
,
videos
=
video_data
[
"videos"
],
video_metadata
=
video_metadata
))
...
...
@@ -1630,6 +1833,9 @@ class GLM4VPlugin(Qwen2VLPlugin):
                    )
                    video_structure += frame_structure

                if not self.expand_mm_tokens:
                    video_structure = self.video_token

                content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
                num_video_tokens += 1
...
...
@@ -1655,7 +1861,11 @@ class GLM4VPlugin(Qwen2VLPlugin):
        return mm_inputs


@dataclass
class Qwen2OmniPlugin(Qwen2VLPlugin):
    audio_bos_token: str = "<|audio_start|>"
    audio_eos_token: str = "<|audio_end|>"

    @override
    def _get_mm_inputs(
        self,
...
...
@@ -1665,6 +1875,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}
        if len(images) != 0:
...
...
@@ -1683,7 +1894,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
-           mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
            mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
            mm_inputs["video_second_per_grid"] = torch.tensor(
                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
...
...
@@ -1729,8 +1940,14 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            if "feature_attention_mask" in mm_inputs:
-               input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
-               audio_lengths = (input_lengths - 2) // 2 + 1
                if processor.__class__.__name__ == "Qwen3OmniMoeProcessor":  # for qwen3omni
                    input_lengths = mm_inputs["feature_attention_mask"].sum(-1)
                    input_lengths_leave = input_lengths % 100
                    feature_lengths = (input_lengths_leave - 1) // 2 + 1
                    audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
                else:
                    input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
                    audio_lengths = (input_lengths - 2) // 2 + 1
        else:
            mm_inputs = {}
            image_grid_thw = [None] * len(images)
...
...
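To sanity-check the Qwen3-Omni branch above, here is its integer arithmetic traced by hand for an assumed feature_attention_mask row summing to 250 (the clip length is an illustrative assumption):

input_lengths = 250                                   # feature_attention_mask.sum(-1) for one clip
input_lengths_leave = input_lengths % 100             # 50
feature_lengths = (input_lengths_leave - 1) // 2 + 1  # 25
audio_lengths = ((feature_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
print(audio_lengths)  # 7 + 2 * 13 = 33 audio tokens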
@@ -1742,7 +1959,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
-                   IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
                    IMAGE_PLACEHOLDER,
                    f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
                    1,
                )
                num_image_tokens += 1
...
...
@@ -1779,7 +1998,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                    video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                    placeholder_string = ""
-                   placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                    placeholder_string += self.vision_bos_token + self.audio_bos_token
                    for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                        video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                        audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
...
...
@@ -1789,7 +2008,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                        if audio_chunk_index is not None:
                            placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
-                   placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                    placeholder_string += self.audio_eos_token + self.vision_eos_token
                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                    num_audio_tokens += 1
...
...
@@ -1798,7 +2017,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
            while AUDIO_PLACEHOLDER in content:
                audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
                content = content.replace(
-                   AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
                    AUDIO_PLACEHOLDER,
                    f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
                    1,
                )
                num_audio_tokens += 1
...
...
@@ -1807,7 +2028,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                    video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                )
                content = content.replace(
-                   VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
                    VIDEO_PLACEHOLDER,
                    f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
                    1,
                )
                num_video_tokens += 1
...
...
@@ -1871,6 +2094,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
    "base": BasePlugin,
    "ernie_vl": ErnieVLPlugin,
    "gemma3": Gemma3Plugin,
    "glm4v": GLM4VPlugin,
    "gemma3n": Gemma3nPlugin,
...
...
@@ -1887,6 +2111,7 @@ PLUGINS = {
"qwen2_audio"
:
Qwen2AudioPlugin
,
"qwen2_omni"
:
Qwen2OmniPlugin
,
"qwen2_vl"
:
Qwen2VLPlugin
,
"qwen3_vl"
:
Qwen3VLPlugin
,
"video_llava"
:
VideoLlavaPlugin
,
}
...
...
@@ -1901,12 +2126,13 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
def get_mm_plugin(
    name: str,
-   image_token: Optional[str] = None,
-   video_token: Optional[str] = None,
-   audio_token: Optional[str] = None,
    image_token: str | None = None,
    video_token: str | None = None,
    audio_token: str | None = None,
    **kwargs,
) -> "BasePlugin":
    r"""Get plugin for multimodal inputs."""
    if name not in PLUGINS:
        raise ValueError(f"Multimodal plugin `{name}` not found.")

-   return PLUGINS[name](image_token, video_token, audio_token)
    return PLUGINS[name](image_token, video_token, audio_token, **kwargs)
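With **kwargs forwarded to the plugin constructor, a template can override a plugin's special tokens at registration time. A usage sketch (the token strings mirror the dots_ocr registration later in this commit):

plugin = get_mm_plugin(
    name="qwen2_vl",
    image_token="<|imgpad|>",
    video_token="<|vidpad|>",
    vision_bos_token="<|img|>",     # extra kwargs reach the plugin's dataclass fields
    vision_eos_token="<|endofimg|>",
)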
src/llamafactory/data/parser.py
View file @ ca625f43
...
...
@@ -15,7 +15,7 @@
import json
import os
from dataclasses import dataclass
-from typing import Any, Literal, Optional
from typing import Any, Literal

from huggingface_hub import hf_hub_download
...
...
@@ -30,43 +30,43 @@ class DatasetAttr:
    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
    dataset_name: str
-   formatting: Literal["alpaca", "sharegpt"] = "alpaca"
    formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
    ranking: bool = False
    # extra configs
-   subset: Optional[str] = None
    subset: str | None = None
    split: str = "train"
-   folder: Optional[str] = None
-   num_samples: Optional[int] = None
    folder: str | None = None
    num_samples: int | None = None
    # common columns
-   system: Optional[str] = None
-   tools: Optional[str] = None
-   images: Optional[str] = None
-   videos: Optional[str] = None
-   audios: Optional[str] = None
    system: str | None = None
    tools: str | None = None
    images: str | None = None
    videos: str | None = None
    audios: str | None = None
    # dpo columns
-   chosen: Optional[str] = None
-   rejected: Optional[str] = None
-   kto_tag: Optional[str] = None
    chosen: str | None = None
    rejected: str | None = None
    kto_tag: str | None = None
    # alpaca columns
-   prompt: Optional[str] = "instruction"
-   query: Optional[str] = "input"
-   response: Optional[str] = "output"
-   history: Optional[str] = None
    prompt: str | None = "instruction"
    query: str | None = "input"
    response: str | None = "output"
    history: str | None = None
    # sharegpt columns
-   messages: Optional[str] = "conversations"
    messages: str | None = "conversations"
    # sharegpt tags
-   role_tag: Optional[str] = "from"
-   content_tag: Optional[str] = "value"
-   user_tag: Optional[str] = "human"
-   assistant_tag: Optional[str] = "gpt"
-   observation_tag: Optional[str] = "observation"
-   function_tag: Optional[str] = "function_call"
-   system_tag: Optional[str] = "system"
    role_tag: str | None = "from"
    content_tag: str | None = "value"
    user_tag: str | None = "human"
    assistant_tag: str | None = "gpt"
    observation_tag: str | None = "observation"
    function_tag: str | None = "function_call"
    system_tag: str | None = "system"

    def __repr__(self) -> str:
        return self.dataset_name

-   def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
    def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
        setattr(self, key, obj.get(key, default))

    def join(self, attr: dict[str, Any]) -> None:
...
...
@@ -90,12 +90,14 @@ class DatasetAttr:
            self.set_attr(tag, attr["tags"])


-def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
    r"""Get the attributes of the datasets."""
    if dataset_names is None:
        dataset_names = []

-   if dataset_dir == "ONLINE":
    if isinstance(dataset_dir, dict):
        dataset_info = dataset_dir
    elif dataset_dir == "ONLINE":
        dataset_info = None
    else:
        if dataset_dir.startswith("REMOTE:"):
...
...
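Accepting a dict for dataset_dir means a dataset_info mapping can now be passed in memory instead of read from disk. A sketch, with hypothetical dataset and file names, also using the newly allowed "openai" formatting value:

dataset_info = {
    "my_chat_data": {  # hypothetical dataset name
        "file_name": "my_chat_data.jsonl",
        "formatting": "openai",
    }
}
dataset_attrs = get_dataset_list(["my_chat_data"], dataset_info)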
src/llamafactory/data/processor/supervised.py
View file @ ca625f43
...
...
@@ -62,7 +62,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
            if self.data_args.train_on_prompt:
                source_label = source_ids
-           elif self.template.efficient_eos:
            elif self.template.efficient_eos and turn_idx != 0:
                source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
            else:
                source_label = [IGNORE_INDEX] * source_len
...
...
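The turn_idx != 0 guard changes which prompt tokens carry an EOS label. A standalone sketch of the masking for two turns (token ids are illustrative; IGNORE_INDEX is -100 as elsewhere in the codebase):

IGNORE_INDEX = -100
eos_token_id = 2              # illustrative
source_ids = [101, 102, 103]
for turn_idx in range(2):
    if turn_idx != 0:         # efficient_eos now applies only from the second turn on
        source_label = [eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
    else:
        source_label = [IGNORE_INDEX] * len(source_ids)
    print(turn_idx, source_label)
# 0 [-100, -100, -100]
# 1 [2, -100, -100]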
src/llamafactory/data/template.py
View file @ ca625f43
...
...
@@ -49,6 +49,7 @@ class Template:
    default_system: str
    stop_words: list[str]
    thought_words: tuple[str, str]
    tool_call_words: tuple[str, str]
    efficient_eos: bool
    replace_eos: bool
    replace_jinja_template: bool
...
...
@@ -96,7 +97,7 @@ class Template:
    def add_thought(self, content: str = "") -> str:
        r"""Add empty thought to assistant message."""
-       return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
        return f"{self.thought_words[0]}{self.thought_words[1]}" + content

    def remove_thought(self, content: str) -> str:
        r"""Remove thought from assistant message."""
...
...
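With the new defaults ("<think>\n", "\n</think>\n\n"), the thought words carry their own newlines, so add_thought no longer injects extra "\n\n" separators. A sketch of the concatenation:

thought_words = ("<think>\n", "\n</think>\n\n")
print(f"{thought_words[0]}{thought_words[1]}" + "The answer is 42.")
# <think>
#
# </think>
#
# The answer is 42.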
@@ -156,7 +157,9 @@ class Template:
            elif message["role"] == Role.OBSERVATION:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION:
-               elements += self.format_function.apply(content=message["content"])
                elements += self.format_function.apply(
                    content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words
                )
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...
...
@@ -199,9 +202,12 @@ class Template:
            logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

        if stop_words:
-           num_added_tokens = tokenizer.add_special_tokens(
-               dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
-           )
            try:
                num_added_tokens = tokenizer.add_special_tokens(
                    dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
                )
            except TypeError:  # tokenizer does not accept replace_additional_special_tokens
                num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))

            logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
            if num_added_tokens > 0:
                logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
...
...
@@ -416,8 +422,8 @@ class ReasoningTemplate(Template):
        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
        if (
-           self.thought_words[0] not in messages[-1]["content"]
-           and self.thought_words[1] not in messages[-1]["content"]
            self.thought_words[0].strip() not in messages[-1]["content"]
            and self.thought_words[1].strip() not in messages[-1]["content"]
        ):  # add empty cot
            if not self.enable_thinking:  # do not compute loss
                prompt_ids += self.get_thought_word_ids(tokenizer)
...
...
@@ -442,8 +448,8 @@ class ReasoningTemplate(Template):
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        for i in range(0, len(messages), 2):
            if (
-               self.thought_words[0] not in messages[i + 1]["content"]
-               and self.thought_words[1] not in messages[i + 1]["content"]
                self.thought_words[0].strip() not in messages[i + 1]["content"]
                and self.thought_words[1].strip() not in messages[i + 1]["content"]
            ):  # add empty cot
                if not self.enable_thinking:  # do not compute loss
                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
...
...
@@ -468,6 +474,7 @@ def register_template(
    default_system: str = "",
    stop_words: Optional[list[str]] = None,
    thought_words: Optional[tuple[str, str]] = None,
    tool_call_words: Optional[tuple[str, str]] = None,
    efficient_eos: bool = False,
    replace_eos: bool = False,
    replace_jinja_template: bool = False,
...
...
@@ -518,7 +525,8 @@ def register_template(
        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        stop_words=stop_words or [],
-       thought_words=thought_words or ("<think>", "</think>"),
        thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
        tool_call_words=tool_call_words or ("<tool_call>", "</tool_call>"),
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        replace_jinja_template=replace_jinja_template,
...
...
@@ -579,7 +587,8 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
        default_system=default_system,
        stop_words=[],
-       thought_words=("<think>", "</think>"),
        thought_words=("<think>\n", "\n</think>\n\n"),
        tool_call_words=("<tool_call>", "</tool_call>"),
        efficient_eos=False,
        replace_eos=False,
        replace_jinja_template=False,
...
...
@@ -616,7 +625,14 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
        template.default_system = data_args.default_system

    if isinstance(template, ReasoningTemplate):
        logger.warning_rank0(
            "You are using reasoning template, "
            "please add `_nothink` suffix if the model is not a reasoning model. "
            "e.g., qwen3_vl_nothink"
        )

    template.enable_thinking = data_args.enable_thinking
    template.fix_special_tokens(tokenizer)
    template.fix_jinja_template(tokenizer)
    return template
...
...
@@ -679,6 +695,23 @@ register_template(
)


register_template(
    name="bailing_v2",
    format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<|role_end|><role>ASSISTANT</role>"]),
    format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}<|role_end|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
    format_observation=StringFormatter(
        slots=[
            "<role>OBSERVATION</role>\n<tool_response>\n{{content}}\n</tool_response><|role_end|><role>ASSISTANT</role>"
        ]
    ),
    format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
    format_tools=ToolFormatter(tool_format="ling"),
    stop_words=["<|endoftext|>"],
    efficient_eos=True,
)


register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
...
@@ -894,12 +927,64 @@ register_template(
)
register_template
(
name
=
"dots_ocr"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>{{content}}<|endofuser|><|assistant|>"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"{{content}}<|endofassistant|>"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>{{content}}<|endofsystem|>
\n
"
]),
stop_words
=
[
"<|endofassistant|>"
],
efficient_eos
=
True
,
mm_plugin
=
get_mm_plugin
(
name
=
"qwen2_vl"
,
image_token
=
"<|imgpad|>"
,
video_token
=
"<|vidpad|>"
,
vision_bos_token
=
"<|img|>"
,
vision_eos_token
=
"<|endofimg|>"
,
),
)
register_template
(
name
=
"empty"
,
format_assistant
=
StringFormatter
(
slots
=
[
"{{content}}"
]),
)
# copied from chatml template
register_template
(
name
=
"ernie"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n\n
<|im_start|>assistant
\n
"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"{{content}}<|im_end|>
\n\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n\n
"
]),
format_observation
=
StringFormatter
(
slots
=
[
"<|im_start|>tool
\n
{{content}}<|im_end|>
\n\n
<|im_start|>assistant
\n
"
]),
default_system
=
"<global_setting>
\n
think_mode=True
\n
</global_setting>"
,
stop_words
=
[
"<|im_end|>"
],
)
register_template
(
name
=
"ernie_nothink"
,
format_user
=
StringFormatter
(
slots
=
[
"User: {{content}}
\n
Assistant: "
]),
format_assistant
=
StringFormatter
(
slots
=
[
"{{content}}<|end_of_sentence|>"
]),
format_system
=
StringFormatter
(
slots
=
[
"{{content}}
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[
"<|begin_of_sentence|>"
]),
stop_words
=
[
"<|end_of_sentence|>"
],
)
register_template
(
name
=
"ernie_vl"
,
format_user
=
StringFormatter
(
slots
=
[
"User: {{content}}"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"
\n
Assistant: {{content}}<|end_of_sentence|>"
]),
format_system
=
StringFormatter
(
slots
=
[
"{{content}}
\n
"
]),
stop_words
=
[
"<|end_of_sentence|>"
],
replace_eos
=
True
,
replace_jinja_template
=
True
,
template_class
=
ReasoningTemplate
,
mm_plugin
=
get_mm_plugin
(
name
=
"ernie_vl"
,
image_token
=
"<|IMAGE_PLACEHOLDER|>"
,
video_token
=
"<|VIDEO_PLACEHOLDER|>"
),
)
register_template
(
name
=
"exaone"
,
format_user
=
StringFormatter
(
slots
=
[
"[|user|]{{content}}
\n
[|assistant|]"
]),
...
...
@@ -1014,6 +1099,22 @@ register_template(
)


# copied from glm4 template
register_template(
    name="glm4_moe",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4_moe"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
    template_class=ReasoningTemplate,
)


# copied from glm4 template
register_template(
    name="glm4v",
...
...
@@ -1031,6 +1132,23 @@ register_template(
)


# copied from glm4 template
register_template(
    name="glm4_5v",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4_moe"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
    stop_words=["<|user|>", "<|observation|>", "</answer>"],
    efficient_eos=True,
    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
    template_class=ReasoningTemplate,
)


# copied from glm4 template
register_template(
    name="glmz1",
...
@@ -1047,6 +1165,18 @@ register_template(
)
register_template
(
name
=
"gpt_oss"
,
format_user
=
StringFormatter
(
slots
=
[
"<|start|>user<|message|>{{content}}<|end|><|start|>assistant"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"{{content}}<|end|>"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|start|>system<|message|>{{content}}<|end|>"
]),
default_system
=
"You are ChatGPT, a large language model trained by OpenAI."
,
thought_words
=
(
"<|channel|>analysis<|message|>"
,
"<|end|><|start|>assistant<|channel|>final<|message|>"
),
efficient_eos
=
True
,
template_class
=
ReasoningTemplate
,
)
register_template
(
name
=
"granite3"
,
format_user
=
StringFormatter
(
...
...
@@ -1071,6 +1201,25 @@ register_template(
)


register_template(
    name="granite4",
    format_user=StringFormatter(
        slots=["<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
    format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"),
    format_observation=StringFormatter(
        slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="default"),
    stop_words=["<|end_of_text|>"],
    default_system="You are Granite, developed by IBM. You are a helpful AI assistant.",
)


register_template(
    name="index",
    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
...
...
@@ -1081,10 +1230,10 @@ register_template(
register_template(
    name="hunyuan",
-   format_user=StringFormatter(slots=["<|bos|>user\n{{content}}<|eos|>\n<|bos|>assistant\n"]),
-   format_assistant=StringFormatter(slots=["{{content}}<|eos|>\n"]),
-   format_system=StringFormatter(slots=["<|bos|>system\n{{content}}<|eos|>\n"]),
-   format_prefix=EmptyFormatter(slots=["<|bos|>"]),
    format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]),
    format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
    format_prefix=EmptyFormatter(slots=["<|startoftext|>"]),
    stop_words=["<|eos|>"],
)
...
...
@@ -1137,6 +1286,35 @@ register_template(
)


register_template(
    name="intern_s1",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
)


# copied from qwen template
register_template(
    name="keye_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
    template_class=ReasoningTemplate,
)


register_template(
    name="kimi_vl",
    format_user=StringFormatter(
...
...
@@ -1432,6 +1610,26 @@ register_template(
    template_class=ReasoningTemplate,
)


# copied from qwen template
register_template(
    name="mimo_v2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    thought_words=("<think>", "</think>"),
    template_class=ReasoningTemplate,
)


# copied from qwen2vl
register_template(
    name="mimo_vl",
...
...
@@ -1470,11 +1668,48 @@ register_template(
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
-   default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    default_system="You are a helpful assistant. You can accept audio and text input and output voice and text.",
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
)


register_template(
    name="minimax1",
    format_user=StringFormatter(
        slots=[
            "<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
    format_system=StringFormatter(
        slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
    ),
    format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
    format_observation=StringFormatter(
        slots=[
            "<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
        ]
    ),
    format_tools=ToolFormatter(tool_format="minimax1"),
    default_system="You are a helpful assistant.",
    stop_words=["<end_of_sentence>"],
)


register_template(
    name="minimax2",
    format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
    format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
    format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
    format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
    format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
    format_tools=ToolFormatter(tool_format="minimax2"),
    default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
    stop_words=["[e~["],
    template_class=ReasoningTemplate,
)


# mistral tokenizer v3 tekken
register_template(
    name="ministral",
...
@@ -1515,6 +1750,19 @@ register_template(
)
register_template
(
name
=
"ministral3"
,
format_user
=
StringFormatter
(
slots
=
[
"[INST]{{content}}[/INST]"
]),
format_system
=
StringFormatter
(
slots
=
[
"{{content}}
\n\n
"
]),
format_function
=
FunctionFormatter
(
slots
=
[
"[TOOL_CALLS]{{content}}"
,
{
"eos_token"
}],
tool_format
=
"mistral"
),
format_observation
=
StringFormatter
(
slots
=
[
"""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""
]),
format_tools
=
ToolFormatter
(
tool_format
=
"mistral"
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
template_class
=
Llama2Template
,
mm_plugin
=
get_mm_plugin
(
name
=
"pixtral"
,
image_token
=
"[IMG]"
),
)
register_template
(
name
=
"olmo"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|assistant|>
\n
"
]),
...
...
@@ -1669,6 +1917,22 @@ register_template(
)


# copied from qwen template
register_template(
    name="qwen3_nothink",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


# copied from chatml template
register_template(
    name="qwen2_audio",
...
...
@@ -1697,10 +1961,55 @@ register_template(
    stop_words=["<|im_end|>"],
    replace_eos=True,
-   mm_plugin=get_mm_plugin(name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"),
    mm_plugin=get_mm_plugin(
        name="qwen2_omni",
        image_token="<|IMAGE|>",
        video_token="<|VIDEO|>",
        audio_token="<|AUDIO|>",
        vision_bos_token="<|vision_bos|>",
        vision_eos_token="<|vision_eos|>",
        audio_bos_token="<|audio_bos|>",
        audio_eos_token="<|audio_eos|>",
    ),
)


register_template(
    name="qwen3_omni",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
    ),
    template_class=ReasoningTemplate,
)


register_template(
    name="qwen3_omni_nothink",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
    ),
)


# copied from qwen template
register_template(
    name="qwen2_vl",
...
...
@@ -1719,6 +2028,41 @@ register_template(
)


# copied from qwen template
register_template(
    name="qwen3_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
    template_class=ReasoningTemplate,
)


# copied from qwen template
register_template(
    name="qwen3_vl_nothink",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)


register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
...
...
@@ -1746,6 +2090,20 @@ register_template(
)


# copied from seed_coder
register_template(
    name="seed_oss",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
    ),
    format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
    format_function=FunctionFormatter(slots=[{"bos_token"}, "\n{{content}}", {"eos_token"}], tool_format="seed_oss"),
    format_tools=ToolFormatter(tool_format="seed_oss"),
    template_class=ReasoningTemplate,
    thought_words=("<seed:think>", "</seed:think>"),
)


# copied from llama3 template
register_template(
    name="skywork_o1",
...
...
src/llamafactory/data/tool_utils.py
View file @ ca625f43
...
...
@@ -42,6 +42,18 @@ GLM4_TOOL_PROMPT = (
"你的任务是针对用户的问题和要求提供适当的答复和支持。
\n\n
# 可用工具{tool_text}"
)
GLM4_MOE_TOOL_PROMPT
=
(
"
\n\n
# Tools
\n\n
You may call one or more functions to assist with the user query.
\n\n
"
"You are provided with function signatures within <tools></tools> XML tags:
\n
<tools>{tool_text}"
"
\n
</tools>
\n\n
For each function call, output the function name and arguments within the following XML format:"
"
\n
<tool_call>{{function-name}}"
"
\n
<arg_key>{{arg-key-1}}</arg_key>"
"
\n
<arg_value>{{arg-value-1}}</arg_value>"
"
\n
<arg_key>{{arg-key-2}}</arg_key>"
"
\n
<arg_value>{{arg-value-2}}</arg_value>"
"
\n
...
\n
</tool_call>
\n
"
)
LLAMA3_TOOL_PROMPT
=
(
"Cutting Knowledge Date: December 2023
\n
Today Date: {date}
\n\n
"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
...
...
@@ -49,6 +61,21 @@ LLAMA3_TOOL_PROMPT = (
"Do not use variables.
\n\n
{tool_text}"
)
MINIMAX_M1_TOOL_PROMPT
=
(
"You are provided with these tools:
\n
<tools>
\n
{tool_text}</tools>
\n\n
"
"If you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and "
"json-object of arguments, following the format below:
\n
<tool_calls>
\n
"
"""{{"name": <tool-name-1>, "arguments": <args-json-object-1>}}
\n
...
\n
</tool_calls>"""
)
MINIMAX_M2_TOOL_PROMPT
=
(
"
\n\n
# Tools
\n\n
You may call one or more tools to assist with the user query.
\n
"
"Here are the tools available in JSONSchema format:
\n\n
<tools>
\n
{tool_text}</tools>
\n\n
"
"When making tool calls, use XML format to invoke tools and pass parameters:
\n
"
"""
\n
<minimax:tool_call>
\n
<invoke name="tool-name-1">
\n
<parameter name="param-key-1">param-value-1</parameter>
\n
"""
"""<parameter name="param-key-2">param-value-2</parameter>
\n
...
\n
</invoke>
\n
</minimax:tool_call>"""
)
QWEN_TOOL_PROMPT
=
(
"
\n\n
# Tools
\n\n
You may call one or more functions to assist with the user query.
\n\n
"
"You are provided with function signatures within <tools></tools> XML tags:
\n
<tools>{tool_text}"
...
...
@@ -57,6 +84,23 @@ QWEN_TOOL_PROMPT = (
""""arguments": <args-json-object>}}
\n
</tool_call>"""
)
SEED_TOOL_PROMPT
=
(
"system
\n
You are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
"Tool List:
\n
You are authorized to use the following tools (described in JSON Schema format). Before performing "
"any task, you must decide how to call them based on the descriptions and parameters of these tools.{tool_text}
\n
"
"工具调用请遵循如下格式:
\n
<seed:tool_call>
\n
<function=example_function_name>
\n
<parameter=example_parameter_1>value_1"
"</parameter>
\n
<parameter=example_parameter_2>This is the value for the second parameter
\n
that can span
\n
multiple "
"lines</parameter>
\n
</function>
\n
</seed:tool_call>
\n
"
)
LING_TOOL_PROMPT
=
(
"# Tools
\n\n
You may call one or more functions to assist with the user query.
\n\n
"
"You are provided with function signatures within <tools></tools> XML tags:
\n
<tools>{tool_text}"
"
\n
</tools>
\n\n
For each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:
\n
<tool_call>
\n
{{"name": <function-name>, """
""""arguments": <args-json-object>}}
\n
</tool_call>"""
)
@
dataclass
class
ToolUtils
(
ABC
):
...
...
@@ -224,6 +268,109 @@ class Llama3ToolUtils(ToolUtils):
        return content


class MiniMaxM1ToolUtils(ToolUtils):
    r"""MiniMax-M1 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += json.dumps(tool, ensure_ascii=False) + "\n"

        return MINIMAX_M1_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            function_texts.append(json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False))

        return "<tool_calls>\n" + "\n".join(function_texts) + "\n</tool_calls>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"<tool_calls>\s*(.+?)\s*</tool_calls>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        results = []
        for line in tool_calls_content.split("\n"):
            line = line.strip()
            if not line:
                continue

            try:
                tool_call = json.loads(line)
                results.append(FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
            except json.JSONDecodeError:
                continue

        return results


class MiniMaxM2ToolUtils(ToolUtils):
    r"""MiniMax-M2 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            tool = tool.get("function", "") if tool.get("type") == "function" else tool
            tool_text += "<tool>" + json.dumps(tool, ensure_ascii=False) + "</tool>\n"

        return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_texts = []
        for func in functions:
            name, arguments = func.name, json.loads(func.arguments)
            prompt = f'<invoke name="{name}">'
            for key, value in arguments.items():
                prompt += f'\n<parameter name="{key}">'
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)

                prompt += value + "</parameter>"

            prompt += "\n</invoke>"
            function_texts.append(prompt)

        # assumed wrapper, matching tool_extractor's <minimax:tool_call> tags below
        return "<minimax:tool_call>\n" + "\n".join(function_texts) + "\n</minimax:tool_call>"

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"<minimax:tool_call>\s*(.+?)\s*</minimax:tool_call>", re.DOTALL)
        tool_match = re.search(regex, content)
        if not tool_match:
            return content

        tool_calls_content = tool_match.group(1)
        invoke_regex = re.compile(r"<invoke name=\"(.*?)\">(.*?)</invoke>", re.DOTALL)
        results = []
        for func_name, params_block in re.findall(invoke_regex, tool_calls_content):
            args_dict = {}
            param_pattern = re.compile(r"<parameter name=\"(.*?)\">(.*?)</parameter>", re.DOTALL)
            for key, raw_value in re.findall(param_pattern, params_block):
                value = raw_value.strip()
                try:
                    parsed_value = json.loads(value)
                except json.JSONDecodeError:
                    parsed_value = raw_value

                args_dict[key] = parsed_value

            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))

        return results


class MistralToolUtils(ToolUtils):
    r"""Mistral v0.3 tool using template."""
...
...
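A round-trip sketch for the new MiniMax-M1 utilities: format one FunctionCall (the (name, arguments) pair used throughout this module), then parse it back; the weather-tool payload is an illustrative assumption:

call_text = MiniMaxM1ToolUtils.function_formatter([FunctionCall("get_weather", '{"city": "Beijing"}')])
print(call_text)
# <tool_calls>
# {"name": "get_weather", "arguments": {"city": "Beijing"}}
# </tool_calls>
print(MiniMaxM1ToolUtils.tool_extractor(call_text))
# [FunctionCall(name='get_weather', arguments='{"city": "Beijing"}')]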
@@ -303,12 +450,113 @@ class QwenToolUtils(ToolUtils):
        return results


class GLM4MOEToolUtils(QwenToolUtils):
    r"""GLM-4-MOE tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

        return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_json = [
            {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
        ]
        function_texts = []
        for func in function_json:
            prompt = "\n<tool_call>" + func["func_name"]
            for key, value in func["func_key_values"].items():
                prompt += "\n<arg_key>" + key + "</arg_key>"
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)

                prompt += "\n<arg_value>" + value + "</arg_value>"

            function_texts.append(prompt)

        return "\n".join(function_texts)


class SeedToolUtils(ToolUtils):
    r"""Seed tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        return SEED_TOOL_PROMPT.format(tool_text="\n" + json.dumps(tools, ensure_ascii=False))

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        function_json = [
            {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
        ]
        function_texts = []
        for func in function_json:
            prompt = "\n<seed:tool_call>\n<function=" + func["func_name"]
            for key, value in func["func_key_values"].items():
                prompt += "\n<parameter=" + key + ">"
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)

                prompt += value + "</parameter>"

            prompt += "\n</function>\n</seed:tool_call>"
            function_texts.append(prompt)

        return "\n".join(function_texts)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        results = []
        regex = re.compile(
            r"<seed:tool_call>\s*<function=\s*([^\s<]+)\s*(.*?)\s*</function>\s*</seed:tool_call>", re.DOTALL
        )
        for func_name, params_block in re.findall(regex, content):
            args_dict = {}
            param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
            for key, raw_value in re.findall(param_pattern, params_block.strip()):
                value = raw_value.strip()
                try:
                    parsed_value = json.loads(value)
                except json.JSONDecodeError:
                    parsed_value = raw_value

                args_dict[key] = parsed_value

            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))

        return results


class LingToolUtils(QwenToolUtils):
    r"""Ling v2 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

        return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"


TOOLS = {
    "default": DefaultToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "minimax1": MiniMaxM1ToolUtils(),
    "minimax2": MiniMaxM2ToolUtils(),
    "mistral": MistralToolUtils(),
    "qwen": QwenToolUtils(),
    "glm4_moe": GLM4MOEToolUtils(),
    "seed_oss": SeedToolUtils(),
    "ling": LingToolUtils(),
}
...
...
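And a formatting sketch for GLM4MOEToolUtils.function_formatter, which emits the <arg_key>/<arg_value> layout that GLM4_MOE_TOOL_PROMPT asks the model to produce (same illustrative call as above):

print(GLM4MOEToolUtils.function_formatter([FunctionCall("get_weather", '{"city": "Beijing"}')]))
#
# <tool_call>get_weather
# <arg_key>city</arg_key>
# <arg_value>Beijing</arg_value>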
src/llamafactory/extras/constants.py
View file @ ca625f43
...
...
@@ -15,7 +15,6 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum, unique
-from typing import Optional

from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
...
...
@@ -56,13 +55,27 @@ LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml"

MCA_SUPPORTED_MODELS = {
    "deepseek_v3",
    "llama",
    "mistral",
    "mixtral",
    "qwen2",
    "qwen2_vl",
    "qwen2_5_vl",
    "qwen3_vl",
    "qwen3",
    "qwen3_moe",
    "qwen3_next",
}

-METHODS = ["full", "freeze", "lora"]
METHODS = ["full", "freeze", "lora", "oft"]

MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

MULTIMODAL_SUPPORTED_MODELS = set()

-PEFT_METHODS = {"lora"}
PEFT_METHODS = {"lora", "oft"}

RUNNING_LOG = "running_log.txt"
...
...
@@ -101,12 +114,14 @@ class AttentionFunction(str, Enum):
    DISABLED = "disabled"
    SDPA = "sdpa"
    FA2 = "fa2"
    FA3 = "fa3"


class EngineName(str, Enum):
    HF = "huggingface"
    VLLM = "vllm"
    SGLANG = "sglang"
    KT = "ktransformers"


class DownloadSource(str, Enum):
...
@@ -126,6 +141,8 @@ class QuantizationMethod(str, Enum):
QUANTO
=
"quanto"
EETQ
=
"eetq"
HQQ
=
"hqq"
MXFP4
=
"mxfp4"
FP8
=
"fp8"
class
RopeScaling
(
str
,
Enum
):
...
...
@@ -137,13 +154,13 @@ class RopeScaling(str, Enum):
def
register_model_group
(
models
:
dict
[
str
,
dict
[
DownloadSource
,
str
]],
template
:
Optional
[
str
]
=
None
,
template
:
str
|
None
=
None
,
multimodal
:
bool
=
False
,
)
->
None
:
for
name
,
path
in
models
.
items
():
SUPPORTED_MODELS
[
name
]
=
path
if
template
is
not
None
and
(
any
(
suffix
in
name
for
suffix
in
(
"-Chat"
,
"-Distill"
,
"-Instruct"
))
or
multimodal
any
(
suffix
in
name
for
suffix
in
(
"-Chat"
,
"-Distill"
,
"-Instruct"
,
"-Thinking"
))
or
multimodal
):
DEFAULT_TEMPLATE
[
name
]
=
template
...
...
@@ -276,7 +293,7 @@ register_model_group(
register_model_group(
    models={
        "ChatGLM2-6B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
            DownloadSource.DEFAULT: "zai-org/chatglm2-6b",
            DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
        }
    },
...
...
@@ -287,11 +304,11 @@ register_model_group(
register_model_group(
    models={
        "ChatGLM3-6B-Base": {
-           DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
            DownloadSource.DEFAULT: "zai-org/chatglm3-6b-base",
            DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
        },
        "ChatGLM3-6B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
            DownloadSource.DEFAULT: "zai-org/chatglm3-6b",
            DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
        },
    },
...
...
@@ -333,7 +350,7 @@ register_model_group(
register_model_group(
    models={
        "CodeGeeX4-9B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
            DownloadSource.DEFAULT: "zai-org/codegeex4-all-9b",
            DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
        },
    },
...
...
@@ -600,6 +617,68 @@ register_model_group(
)


register_model_group(
    models={
        "dots.ocr": {
            DownloadSource.DEFAULT: "rednote-hilab/dots.ocr",
            DownloadSource.MODELSCOPE: "rednote-hilab/dots.ocr",
        },
    },
    template="dots_ocr",
    multimodal=True,
)


register_model_group(
    models={
        "ERNIE-4.5-21B-A3B-Thinking": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-Thinking",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-Thinking",
        },
    },
    template="ernie",
)


register_model_group(
    models={
        "ERNIE-4.5-0.3B-PT": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-0.3B-PT",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-0.3B-PT",
        },
        "ERNIE-4.5-21B-A3B-PT": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-21B-A3B-PT",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-21B-A3B-PT",
        },
        "ERNIE-4.5-300B-A47B-PT": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-300B-A47B-PT",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-300B-A47B-PT",
        },
    },
    template="ernie_nothink",
)


register_model_group(
    models={
        "ERNIE-4.5-VL-28B-A3B-PT": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-PT",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT",
        },
        "ERNIE-4.5-VL-28B-A3B-Thinking": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-28B-A3B-Thinking",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Thinking",
        },
        "ERNIE-4.5-VL-424B-A47B-Base-PT": {
            DownloadSource.DEFAULT: "baidu/ERNIE-4.5-VL-424B-A47B-PT",
            DownloadSource.MODELSCOPE: "PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT",
        },
    },
    template="ernie_vl",
    multimodal=True,
)


register_model_group(
    models={
        "EXAONE-3.0-7.8B-Instruct": {
...
...
@@ -644,6 +723,7 @@ register_model_group(
    template="falcon",
)


register_model_group(
    models={
        "Falcon-H1-0.5B-Base": {
...
...
@@ -756,10 +836,18 @@ register_model_group(
            DownloadSource.DEFAULT: "google/gemma-2-27b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
        },
        "Gemma-3-270M": {
            DownloadSource.DEFAULT: "google/gemma-3-270m",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m",
        },
        "Gemma-3-1B": {
            DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
        },
        "Gemma-3-270M-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3-270m-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-270m-it",
        },
        "Gemma-3-1B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3-1b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
...
...
@@ -807,6 +895,10 @@ register_model_group(
            DownloadSource.DEFAULT: "google/medgemma-4b-it",
            DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
        },
        "MedGemma-27B-Instruct": {
            DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
            DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
        },
    },
    template="gemma3",
    multimodal=True,
...
...
@@ -840,28 +932,28 @@ register_model_group(
register_model_group(
    models={
        "GLM-4-9B": {
-           DownloadSource.DEFAULT: "THUDM/glm-4-9b",
            DownloadSource.DEFAULT: "zai-org/glm-4-9b",
            DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
        },
        "GLM-4-9B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
            DownloadSource.DEFAULT: "zai-org/glm-4-9b-chat",
            DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
            DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
        },
        "GLM-4-9B-1M-Chat": {
-           DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
            DownloadSource.DEFAULT: "zai-org/glm-4-9b-chat-1m",
            DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
        },
        "GLM-4-0414-9B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/GLM-4-9B-0414",
            DownloadSource.DEFAULT: "zai-org/GLM-4-9B-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-0414",
        },
        "GLM-4-0414-32B-Base": {
-           DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Base-0414",
            DownloadSource.DEFAULT: "zai-org/GLM-4-32B-Base-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Base-0414",
        },
        "GLM-4-0414-32B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
            DownloadSource.DEFAULT: "zai-org/GLM-4-32B-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
        },
    },
...
...
@@ -872,11 +964,11 @@ register_model_group(
register_model_group(
    models={
        "GLM-4.1V-9B-Base": {
-           DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Base",
            DownloadSource.DEFAULT: "zai-org/GLM-4.1V-9B-Base",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Base",
        },
        "GLM-4.1V-9B-Thinking": {
-           DownloadSource.DEFAULT: "THUDM/GLM-4.1V-9B-Thinking",
            DownloadSource.DEFAULT: "zai-org/GLM-4.1V-9B-Thinking",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.1V-9B-Thinking",
        },
    },
...
...
@@ -885,14 +977,57 @@ register_model_group(
)


register_model_group(
    models={
        "GLM-4.5-Air-Base": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.5-Air-Base",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Air-Base",
        },
        "GLM-4.5-Base": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.5-Base",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Base",
        },
        "GLM-4.5-Air-Thinking": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.5-Air",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5-Air",
        },
        "GLM-4.5-Thinking": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.5",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5",
        },
    },
    template="glm4_moe",
)


register_model_group(
    models={
        "GLM-4.5V-Air-Thinking": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.5V",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.5V",
        },
        "GLM-4.6V": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.6V",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V",
        },
        "GLM-4.6V-Flash": {
            DownloadSource.DEFAULT: "zai-org/GLM-4.6V-Flash",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.6V-Flash",
        },
    },
    template="glm4_5v",
    multimodal=True,
)


register_model_group(
    models={
        "GLM-Z1-0414-9B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
            DownloadSource.DEFAULT: "zai-org/GLM-Z1-9B-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
        },
        "GLM-Z1-0414-32B-Chat": {
-           DownloadSource.DEFAULT: "THUDM/GLM-Z1-32B-0414",
            DownloadSource.DEFAULT: "zai-org/GLM-Z1-32B-0414",
            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
        },
    },
...
...
@@ -922,6 +1057,55 @@ register_model_group(
)


register_model_group(
    models={
        "GPT-OSS-20B-Thinking": {
            DownloadSource.DEFAULT: "openai/gpt-oss-20b",
            DownloadSource.MODELSCOPE: "openai/gpt-oss-20b",
        },
        "GPT-OSS-120B-Thinking": {
            DownloadSource.DEFAULT: "openai/gpt-oss-120b",
            DownloadSource.MODELSCOPE: "openai/gpt-oss-120b",
        },
    },
    template="gpt_oss",
)


register_model_group(
    models={
        "MiniMax-Text-01-Instruct": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01",
        },
        "MiniMax-M1-40k-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf",
        },
        "MiniMax-M1-80k-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf",
        },
    },
    template="minimax1",
)


register_model_group(
    models={
        "MiniMax-M2-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2",
        },
        "MiniMax-M2.1-Thinking": {
            DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1",
            DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1",
        },
    },
    template="minimax2",
)


register_model_group(
    models={
        "Granite-3.0-1B-A400M-Base": {
...
...
@@ -1029,12 +1213,27 @@ register_model_group(
)


register_model_group(
    models={
        "Granite-4.0-tiny-preview": {
            DownloadSource.DEFAULT: "ibm-granite/granite-4.0-tiny-preview",
            DownloadSource.MODELSCOPE: "ibm-granite/granite-4.0-tiny-preview",
        },
    },
    template="granite4",
)


register_model_group(
    models={
        "Hunyuan-7B-Instruct": {
            DownloadSource.DEFAULT: "tencent/Hunyuan-7B-Instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/Hunyuan-7B-Instruct",
        },
        "Hunyuan-MT-7B-Instruct": {
            DownloadSource.DEFAULT: "tencent/Hunyuan-MT-7B",
            DownloadSource.MODELSCOPE: "Tencent-Hunyuan/Hunyuan-MT-7B",
        },
    },
    template="hunyuan",
)
...
...
@@ -1185,12 +1384,52 @@ register_model_group(
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3-78B-hf",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3-78B-hf",
        },
        "InternVL3.5-1B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-1B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-1B-HF",
        },
        "InternVL3.5-2B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-2B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-2B-HF",
        },
        "InternVL3.5-4B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-4B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-4B-HF",
        },
        "InternVL3.5-8B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-8B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-8B-HF",
        },
        "InternVL3.5-14B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-14B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-14B-HF",
        },
        "InternVL3.5-30B-A3B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-30B-A3B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-30B-A3B-HF",
        },
        "InternVL3.5-38B-hf": {
            DownloadSource.DEFAULT: "OpenGVLab/InternVL3_5-38B-HF",
            DownloadSource.MODELSCOPE: "OpenGVLab/InternVL3_5-38B-HF",
        },
    },
    template="intern_vl",
    multimodal=True,
)
register_model_group(
    models={
        "Intern-S1-mini": {
            DownloadSource.DEFAULT: "internlm/Intern-S1-mini",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/Intern-S1-mini",
        }
    },
    template="intern_s1",
    multimodal=True,
)
register_model_group(
    models={
        "Jamba-v0.1": {
...
@@ -1201,6 +1440,18 @@ register_model_group(
)


register_model_group(
    models={
        "Keye-VL-8B-Chat": {
            DownloadSource.DEFAULT: "Kwai-Keye/Keye-VL-8B-Preview",
            DownloadSource.MODELSCOPE: "Kwai-Keye/Keye-VL-8B-Preview",
        },
    },
    template="keye_vl",
    multimodal=True,
)


register_model_group(
    models={
        "Kimi-Dev-72B-Instruct": {
...
@@ -1589,20 +1840,51 @@ register_model_group(
register_model_group(
    models={
        # the "MiMo-7B-VL-Instruct" entry that previously lived here moved to
        # the "qwen2_vl" group below; this group now holds the text-only V2 models
        "MiMo-V2-Flash-Base": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash-Base",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash-Base",
        },
        "MiMo-V2-Flash": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-V2-Flash",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-V2-Flash",
        },
    },
    template="mimo_v2",
)


register_model_group(
    models={
        "MiMo-7B-VL-RL": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
        },
        "MiMo-VL-7B-RL-2508": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
        },
    },
    template="mimo_vl",
    multimodal=True,
)


register_model_group(
    models={
        "MiMo-7B-VL-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
        },
        "MiMo-VL-7B-SFT-2508": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
        },
    },
    template="qwen2_vl",
    multimodal=True,
)


register_model_group(
    models={
        "MiniCPM-2B-SFT-Chat": {
...
@@ -1640,6 +1922,10 @@ register_model_group(
            DownloadSource.DEFAULT: "openbmb/MiniCPM4-8B",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM4-8B",
        },
        "MiniCPM4.1-8B-Chat": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM4.1-8B",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM4.1-8B",
        },
    },
    template="cpm4",
)
...
@@ -1647,7 +1933,7 @@ register_model_group(
register_model_group(
    models={
        "MiniCPM-o-2.6": {  # display key renamed from "MiniCPM-o-2_6"
            DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
        },
...
@@ -1659,7 +1945,7 @@ register_model_group(
register_model_group(
    models={
        "MiniCPM-V-2.6": {  # display key renamed from "MiniCPM-V-2_6"
            DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
        },
...
@@ -1669,6 +1955,30 @@ register_model_group(
)


register_model_group(
    models={
        "MiniCPM-V-4": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4",
        },
    },
    template="minicpm_v",
    multimodal=True,
)


register_model_group(
    models={
        "MiniCPM-V-4.5": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM-V-4_5",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-4_5",
        },
    },
    template="minicpm_v",
    multimodal=True,
)


register_model_group(
    models={
        "Ministral-8B-Instruct-2410": {
...
@@ -1718,6 +2028,37 @@ register_model_group(
    template="mistral",
)


register_model_group(
    models={
        "Ministral-3-3B-Base-2512": {
            DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Base-2512",
            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Base-2512",
        },
        "Ministral-3-8B-Base-2512": {
            DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Base-2512",
            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Base-2512",
        },
        "Ministral-3-14B-Base-2512": {
            DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Base-2512",
            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Base-2512",
        },
        "Ministral-3-3B-Instruct-2512": {
            DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
        },
        "Ministral-3-8B-Instruct-2512": {
            DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
        },
        "Ministral-3-14B-Instruct-2512": {
            DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
        },
    },
    template="ministral3",
    multimodal=True,
)


register_model_group(
    models={
...
@@ -1777,6 +2118,37 @@ register_model_group(
)


register_model_group(
    models={
        "MobileLLM-R1-140M-Base": {
            DownloadSource.DEFAULT: "facebook/MobileLLM-R1-140M-base",
            DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-140M-base",
        },
        "MobileLLM-R1-360M-Base": {
            DownloadSource.DEFAULT: "facebook/MobileLLM-R1-360M-base",
            DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-360M-base",
        },
        "MobileLLM-R1-950M-Base": {
            DownloadSource.DEFAULT: "facebook/MobileLLM-R1-950M-base",
            DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-950M-base",
        },
        "MobileLLM-R1-140M-Instruct": {
            DownloadSource.DEFAULT: "facebook/MobileLLM-R1-140M",
            DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-140M",
        },
        "MobileLLM-R1-360M-Instruct": {
            DownloadSource.DEFAULT: "facebook/MobileLLM-R1-360M",
            DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-360M",
        },
        "MobileLLM-R1-950M-Instruct": {
            DownloadSource.DEFAULT: "facebook/MobileLLM-R1-950M",
            DownloadSource.MODELSCOPE: "facebook/MobileLLM-R1-950M",
        },
    },
    template="llama3",
)


register_model_group(
    models={
        "Moonlight-16B-A3B": {
...
@@ -2669,75 +3041,114 @@ register_model_group(
            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Base",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Base",
        },
        # every "-Instruct" key below was renamed to "-Thinking" in this commit,
        # and the dated "-2507" checkpoints were added alongside them
        "Qwen3-0.6B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B",
        },
        "Qwen3-1.7B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B",
        },
        "Qwen3-4B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-4B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B",
        },
        "Qwen3-4B-Thinking-2507": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Thinking-2507",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Thinking-2507",
        },
        "Qwen3-8B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-8B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B",
        },
        "Qwen3-14B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-14B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B",
        },
        "Qwen3-32B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-32B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B",
        },
        "Qwen3-30B-A3B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B",
        },
        "Qwen3-30B-A3B-Thinking-2507": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Thinking-2507",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Thinking-2507",
        },
        "Qwen3-235B-A22B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
        },
        "Qwen3-235B-A22B-Thinking-2507": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-Thinking-2507",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-Thinking-2507",
        },
        "Qwen3-0.6B-Thinking-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
        },
        "Qwen3-1.7B-Thinking-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
        },
        "Qwen3-4B-Thinking-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
        },
        "Qwen3-8B-Thinking-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
        },
        "Qwen3-14B-Thinking-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
        },
        "Qwen3-32B-Thinking-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
        },
        "Qwen3-30B-A3B-Thinking-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
        },
        "Qwen3-235B-A22B-Thinking-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
        },
        "Qwen3-Next-80B-A3B-Thinking": {  # the diff shows this key with a stray "Qwen/" prefix; normalized here
            DownloadSource.DEFAULT: "Qwen/Qwen3-Next-80B-A3B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Next-80B-A3B-Thinking",
        },
    },
    template="qwen3",
)


register_model_group(
    models={
        "Qwen3-4B-Instruct-2507": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-Instruct-2507",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-Instruct-2507",
        },
        "Qwen3-30B-A3B-Instruct-2507": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-Instruct-2507",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-Instruct-2507",
        },
        "Qwen3-235B-A22B-Instruct-2507": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-Instruct-2507",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-Instruct-2507",
        },
        "Qwen3-Next-80B-A3B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Next-80B-A3B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Next-80B-A3B-Instruct",
        },
    },
    template="qwen3_nothink",
)


register_model_group(
    models={
        "Qwen2-Audio-7B": {
...
@@ -2778,6 +3189,34 @@ register_model_group(
)


register_model_group(
    models={
        "Qwen3-Omni-30B-A3B-Captioner": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
        },
        "Qwen3-Omni-30B-A3B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        },
    },
    template="qwen3_omni_nothink",
    multimodal=True,
)


register_model_group(
    models={
        "Qwen3-Omni-30B-A3B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
        },
    },
    template="qwen3_omni",
    multimodal=True,
)


register_model_group(
    models={
        "Qwen2-VL-2B": {
...
@@ -2880,22 +3319,108 @@ register_model_group(
)


register_model_group(
    models={
        "Qwen3-VL-2B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Instruct",
        },
        "Qwen3-VL-4B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Instruct",
        },
        "Qwen3-VL-8B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Instruct",
        },
        "Qwen3-VL-32B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Instruct",
        },
        "Qwen3-VL-30B-A3B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Instruct",
        },
        "Qwen3-VL-235B-A22B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
        },
    },
    template="qwen3_vl_nothink",
    multimodal=True,
)


register_model_group(
    models={
        "Qwen3-VL-2B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-2B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-2B-Thinking",
        },
        "Qwen3-VL-4B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-4B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-4B-Thinking",
        },
        "Qwen3-VL-8B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-8B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-8B-Thinking",
        },
        "Qwen3-VL-32B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-32B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-32B-Thinking",
        },
        "Qwen3-VL-30B-A3B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-30B-A3B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-30B-A3B-Thinking",
        },
        "Qwen3-VL-235B-A22B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
        },
    },
    template="qwen3_vl",
    multimodal=True,
)
register_model_group(
    models={
        "Seed-Coder-8B-Base": {
            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
            DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-Coder-8B-Base",
        },
        "Seed-Coder-8B-Instruct": {
            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
            DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-Coder-8B-Instruct",
        },
        "Seed-Coder-8B-Thinking": {  # renamed from "Seed-Coder-8B-Instruct-Reasoning"
            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
            DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
        },
    },
    template="seed_coder",
)


register_model_group(
    models={
        "Seed-OSS-36B-Base": {
            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-OSS-36B-Base",
            DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-OSS-36B-Base",
        },
        "Seed-OSS-36B-Base-woSyn": {
            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-OSS-36B-Base-woSyn",
            DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-OSS-36B-Base-woSyn",
        },
        "Seed-OSS-36B-Instruct": {
            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-OSS-36B-Instruct",
            DownloadSource.MODELSCOPE: "ByteDance-Seed/Seed-OSS-36B-Instruct",
        },
    },
    template="seed_oss",
)


register_model_group(
    models={
        "Skywork-13B-Base": {
...
@@ -3057,6 +3582,17 @@ register_model_group(
)


register_model_group(
    models={
        "VibeThinker-1.5B": {
            DownloadSource.DEFAULT: "WeiboAI/VibeThinker-1.5B",
            DownloadSource.MODELSCOPE: "WeiboAI/VibeThinker-1.5B",
        },
    },
    template="qwen3",
)


register_model_group(
    models={
        "Vicuna-v1.5-7B-Chat": {
...
src/llamafactory/extras/env.py
...
@@ -15,33 +15,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict


VERSION = "0.9.4"  # bumped from "0.9.4.dev0"


def print_env() -> None:
    # third-party imports moved from module level into the function body,
    # so importing this module no longer pulls in the heavy dependencies
    import os
    import platform

    import accelerate
    import datasets
    import peft
    import torch
    import transformers
    from transformers.utils import is_torch_cuda_available, is_torch_npu_available

    info = OrderedDict(
        {
            "`llamafactory` version": VERSION,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version": torch.__version__,
            "Transformers version": transformers.__version__,
            "Datasets version": datasets.__version__,
            "Accelerate version": accelerate.__version__,
            "PEFT version": peft.__version__,
            # "TRL version" left the base dict; it is now reported lazily below
        }
    )

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
...
@@ -54,6 +57,13 @@ def print_env() -> None:
        info["NPU type"] = torch.npu.get_device_name()
        info["CANN version"] = torch.version.cann

    try:
        import trl  # type: ignore

        info["TRL version"] = trl.__version__
    except Exception:
        pass

    try:
        import deepspeed  # type: ignore
...
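A quick way to exercise the new lazy-import behavior. This is a hypothetical snippet: the printing logic of print_env is elided above, so the exact output format is assumed, and the only hard guarantee shown is that missing optional backends (trl, deepspeed) no longer break the module import.

# Hypothetical usage sketch for the refactored env module.
from llamafactory.extras.env import VERSION, print_env

print(VERSION)  # "0.9.4"
print_env()     # presumably prints one "key: value" line per entry, in insertion order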
src/llamafactory/extras/logging.py
...
@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
    library_root_logger.propagate = False


def get_logger(name: str | None = None) -> "_Logger":  # signature changed from Optional[str]
    r"""Return a logger with the specified name. It is not supposed to be accessed externally."""
    if name is None:
        name = _get_library_name()
...
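The `str | None` spelling is PEP 604 union syntax: equivalent at runtime to `Optional[str]`, but only valid as a live annotation on Python 3.10+ (or earlier interpreters with `from __future__ import annotations`). The equivalence can be checked directly:

# PEP 604 unions and typing.Optional compare equal at runtime (Python 3.10+).
from typing import Optional

assert (str | None) == Optional[str]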
src/llamafactory/extras/misc.py
...
@@ -18,7 +18,7 @@
import gc
import os
import socket
from typing import TYPE_CHECKING, Any, Literal, Optional, Union  # Optional added for infer_optim_dtype

import torch
import torch.distributed as dist
...
@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
    r"""Check the version of the required packages."""
    # bounds raised from transformers>=4.49.0,<=4.52.4,!=4.52.0, datasets<=3.6.0,
    # accelerate<=1.7.0, peft<=0.15.2, trl>=0.8.6,<=0.9.6
    check_version("transformers>=4.51.0,<=4.57.1")
    check_version("datasets>=2.16.0,<=4.0.0")
    check_version("accelerate>=1.3.0,<=1.11.0")
    check_version("peft>=0.14.0,<=0.17.1")
    check_version("trl>=0.18.0,<=0.24.0")
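The requirement strings follow PEP 440. The body of `check_version` is not shown in this hunk; a minimal sketch of what such a helper plausibly does (the name `check_version_sketch` is ours, and the real helper also takes a `mandatory` flag):

# Minimal sketch, assuming check_version validates an installed package
# against a PEP 440 specifier -- not the project's actual helper.
from importlib.metadata import version as get_version

from packaging.requirements import Requirement


def check_version_sketch(requirement: str) -> None:
    req = Requirement(requirement)  # e.g. "trl>=0.18.0,<=0.24.0"
    installed = get_version(req.name)
    if not req.specifier.contains(installed, prereleases=True):
        raise RuntimeError(f"{req.name}=={installed} does not satisfy {requirement!r}")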
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
...
return
os
.
path
.
isdir
(
path
)
and
len
(
os
.
listdir
(
path
))
>
0
def
infer_optim_dtype
(
model_dtype
:
"torch.dtype"
)
->
"torch.dtype"
:
def
infer_optim_dtype
(
model_dtype
:
Optional
[
"torch.dtype"
]
)
->
"torch.dtype"
:
r
"""Infer the optimal dtype according to the model_dtype and device compatibility."""
if
_is_bf16_available
and
model_dtype
==
torch
.
bfloat16
:
if
_is_bf16_available
and
(
model_dtype
==
torch
.
bfloat16
or
model_dtype
is
None
)
:
return
torch
.
bfloat16
elif
_is_fp16_available
:
return
torch
.
float16
...
...
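The widened signature makes `None` mean "pick the best dtype the hardware supports". The tail of the function is elided above; assuming it falls back to torch.float32, usage looks like:

# Sketch: None now falls through the availability checks
# (bf16 if supported, else fp16, else -- presumably -- fp32).
import torch

dtype = infer_optim_dtype(model_dtype=None)
assert dtype in (torch.bfloat16, torch.float16, torch.float32)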
@@ -313,6 +313,10 @@ def use_ray() -> bool:
    return is_env_enabled("USE_RAY")


def use_kt() -> bool:
    return is_env_enabled("USE_KT")


def find_available_port() -> int:
    r"""Find an available port on the local machine."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
...
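`is_env_enabled` is defined elsewhere in this module and not shown here. Assuming the usual truthy-string check, the new USE_KT flag behaves like this sketch (the helper name and accepted values are assumptions):

# Sketch of the assumed is_env_enabled contract backing use_kt().
import os


def is_env_enabled_sketch(name: str, default: str = "0") -> bool:
    return os.getenv(name, default).lower() in ("true", "y", "1")


os.environ["USE_KT"] = "1"
assert is_env_enabled_sketch("USE_KT")  # use_kt() would now return True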
@@ -328,3 +332,7 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
    if ipv6_enabled:
        os.environ.pop("http_proxy", None)
        os.environ.pop("HTTP_PROXY", None)
        os.environ.pop("https_proxy", None)
        os.environ.pop("HTTPS_PROXY", None)
        os.environ.pop("all_proxy", None)
        os.environ.pop("ALL_PROXY", None)
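With the four added pops, enabling IPv6 now clears every common proxy variable rather than just the HTTP pair. A hypothetical usage check (the proxy URL is a placeholder):

# Example: after fix_proxy(ipv6_enabled=True), no proxy variables remain.
import os

os.environ["https_proxy"] = "http://127.0.0.1:7890"
fix_proxy(ipv6_enabled=True)
assert "https_proxy" not in os.environ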
src/llamafactory/extras/packages.py
...
@@ -58,6 +58,10 @@ def is_apollo_available():
    return _is_package_available("apollo_torch")


def is_jieba_available():
    return _is_package_available("jieba")


def is_gradio_available():
    return _is_package_available("gradio")
...
@@ -66,6 +70,10 @@ def is_matplotlib_available():
    return _is_package_available("matplotlib")


def is_mcore_adapter_available():
    return _is_package_available("mcore_adapter")


def is_pillow_available():
    return _is_package_available("PIL")
...
@@ -74,6 +82,10 @@ def is_ray_available():
    return _is_package_available("ray")


def is_kt_available():
    return _is_package_available("ktransformers")


def is_requests_available():
    return _is_package_available("requests")
...
@@ -82,6 +94,14 @@ def is_rouge_available():
    return _is_package_available("rouge_chinese")


def is_safetensors_available():
    return _is_package_available("safetensors")


def is_sglang_available():  # moved up from the end of the file to keep alphabetical order
    return _is_package_available("sglang")


def is_starlette_available():
    return _is_package_available("sse_starlette")
...
@@ -91,13 +111,14 @@ def is_transformers_version_greater_than(content: str):
    return _get_package_version("transformers") >= version.parse(content)


@lru_cache
def is_torch_version_greater_than(content: str):
    return _get_package_version("torch") >= version.parse(content)


def is_uvicorn_available():
    return _is_package_available("uvicorn")


def is_vllm_available():
    return _is_package_available("vllm")
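All of these probes funnel through `_is_package_available` and `_get_package_version`, which are defined above the hunks shown. A minimal sketch of the usual importlib-based pattern, assuming that is roughly what they do (the `_sketch` names are ours):

# Minimal sketch of the assumed probe helpers -- not the module's exact code.
import importlib.metadata
import importlib.util
from functools import lru_cache

from packaging import version


@lru_cache
def _is_package_available_sketch(name: str) -> bool:
    # True if the package can be imported, without actually importing it.
    return importlib.util.find_spec(name) is not None


@lru_cache
def _get_package_version_sketch(name: str) -> "version.Version":
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")  # missing packages compare below any requirement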